From 29deb18f9f7bc4aef6ca026e02ea907ca368d2d6 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 17 Jul 2017 17:44:24 +0100 Subject: [PATCH] Moved memberships update to the database package --- .../auth/storage/accounts/storage.go | 90 ++++++++++++++----- .../clientapi/consumers/roomserver.go | 52 +---------- 2 files changed, 69 insertions(+), 73 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go index 909c4e7a5..c7fa0e474 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go @@ -32,6 +32,7 @@ type Database struct { accounts accountsStatements profiles profilesStatements memberships membershipStatements + serverName gomatrixserverlib.ServerName } // NewDatabase creates a new accounts and profiles database @@ -57,7 +58,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) if err = m.prepare(db); err != nil { return nil, err } - return &Database{db, partitions, a, p, m}, nil + return &Database{db, partitions, a, p, m, serverName}, nil } // GetAccountByPassword returns the account associated with the given localpart and password. @@ -122,36 +123,81 @@ func (d *Database) SaveMembership(localpart string, roomID string, eventID strin return d.memberships.insertMembership(localpart, roomID, eventID, txn) } -// RemoveMembershipsByEventIDs removes the memberships of which the `join` membership +// removeMembershipsByEventIDs removes the memberships of which the `join` membership // event ID is included in a given array of events IDs // If the removal fails, or if there is no membership to remove, returns an error -func (d *Database) RemoveMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) error { +func (d *Database) removeMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) error { return d.memberships.deleteMembershipsByEventIDs(eventIDs, txn) } -// StartTransaction begins a new SQL transaction and returns it -// If there was an error during the transaction initialisation, returns it -func (d *Database) StartTransaction() (*sql.Tx, error) { - return d.db.Begin() -} - -// EndTransation is called at the end of a transaction started with StartTransaction -// If called with an error, the transaction will rollback, if called with nil the -// transaction will commit -// If there was an error during either rollback or commit, returns it -func (d *Database) EndTransation(txn *sql.Tx, err error) error { +// UpdateMemberships adds the "join" membership events included in a given state +// events array, and removes those which ID is included in a given array of events +// IDs. All of the process is run in a transaction, which commits only once/if every +// insertion and deletion has been successfully processed. +// Returns a SQL error if there was an issue with any part of the process +func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error { + txn, err := d.db.Begin() if err != nil { - return txn.Rollback() + if e := txn.Rollback(); e != nil { + return e + } + return err } - return txn.Commit() + + if err := d.removeMembershipsByEventIDs(idsToRemove, txn); err != nil { + if e := txn.Rollback(); e != nil { + return e + } + return err + } + + for _, event := range eventsToAdd { + if err := d.newMembership(event, txn); err != nil { + if e := txn.Rollback(); e != nil { + return e + } + return err + } + } + + if err := txn.Commit(); err != nil { + return err + } + + return nil } -// GetMembershipByEventID returns the membership (as a user localpart and a room ID) -// for which the `join` membership event ID matches a given event ID -// If no membership match this event ID, the localpart and room ID will be empty strings -// If an error happens during the retrieval, returns the SQL error -func (d *Database) GetMembershipByEventID(eventID string) (string, string, error) { - return d.memberships.selectMembershipByEventID(eventID) +// newMembership will save a new membership in the database if the given state +// event is a "join" membership event +// If the event isn't a "join" membership event, does nothing +// If an error occurred, returns it +func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error { + if ev.Type() == "m.room.member" && ev.StateKey() != nil { + localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) + if err != nil { + return err + } + + // We only want state events from local users + if string(serverName) != string(d.serverName) { + return nil + } + + eventID := ev.EventID() + roomID := ev.RoomID() + membership, err := ev.Membership() + if err != nil { + return err + } + + // Only "join" membership events can be considered as new memberships + if membership == "join" { + if err := d.SaveMembership(localpart, roomID, eventID, txn); err != nil { + return err + } + } + } + return nil } func hashPassword(plaintext string) (hash string, err error) { diff --git a/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go b/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go index ecad75ed2..98dcd5b65 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go +++ b/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go @@ -15,7 +15,6 @@ package consumers import ( - "database/sql" "encoding/json" log "github.com/Sirupsen/logrus" @@ -95,25 +94,7 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error { return err } - txn, err := s.db.StartTransaction() - if err != nil { - return err - } - - if err := s.db.RemoveMembershipsByEventIDs(output.NewRoomEvent.RemovesStateEventIDs, txn); err != nil { - if e := s.db.EndTransation(txn, err); e != nil { - return e - } - return err - } - - for _, event := range events { - if err := s.updateMembership(event, txn); err != nil { - return err - } - } - - if err := s.db.EndTransation(txn, nil); err != nil { + if err := s.db.UpdateMemberships(events, output.NewRoomEvent.RemovesStateEventIDs); err != nil { return err } @@ -158,34 +139,3 @@ func (s *OutputRoomEvent) lookupStateEvents( return result, nil } - -func (s *OutputRoomEvent) updateMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error { - if ev.Type() == "m.room.member" && ev.StateKey() != nil { - localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) - if err != nil { - return err - } - - // we only want state events from local users - if string(serverName) != s.serverName { - return nil - } - - eventID := ev.EventID() - roomID := ev.RoomID() - membership, err := ev.Membership() - if err != nil { - return err - } - - if membership == "join" { - if err := s.db.SaveMembership(localpart, roomID, eventID, txn); err != nil { - if e := s.db.EndTransation(txn, err); e != nil { - return e - } - return err - } - } - } - return nil -}