diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go index be09ecc33..39fa0ca0f 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go @@ -25,6 +25,8 @@ CREATE TABLE IF NOT EXISTS memberships ( localpart TEXT NOT NULL, -- The room this user is a member of room_id TEXT NOT NULL, + -- The ID of the join membership event + event_id TEXT NOT NULL, -- A user can only be member of a room once PRIMARY KEY (localpart, room_id) @@ -32,20 +34,28 @@ CREATE TABLE IF NOT EXISTS memberships ( ` const insertMembershipSQL = "" + - "INSERT INTO memberships(localpart, room_id) VALUES ($1, $2)" + "INSERT INTO memberships(localpart, room_id, event_id) VALUES ($1, $2, $3)" const selectMembershipSQL = "" + "SELECT * from memberships WHERE localpart = $1 AND room_id = $2" +const selectMembershipByEventIDSQL = "" + + "SELECT localpart, room_id FROM memberships WHERE event_id = $1" + const selectMembershipsByLocalpartSQL = "" + "SELECT room_id FROM memberships WHERE localpart = $1" const deleteMembershipSQL = "" + "DELETE FROM memberships WHERE localpart = $1 AND room_id = $2" +const deleteMembershipByEventIDSQL = "" + + "DELETE FROM memberships WHERE event_id = $1" + type membershipStatements struct { + deleteMembershipByEventIDStmt *sql.Stmt deleteMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt + selectMembershipByEventIDStmt *sql.Stmt selectMembershipsByLocalpartStmt *sql.Stmt selectMembershipStmt *sql.Stmt } @@ -55,12 +65,18 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { if err != nil { return } + if s.deleteMembershipByEventIDStmt, err = db.Prepare(deleteMembershipByEventIDSQL); err != nil { + return + } if s.deleteMembershipStmt, err = db.Prepare(deleteMembershipSQL); err != nil { return } if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil { return } + if s.selectMembershipByEventIDStmt, err = db.Prepare(selectMembershipByEventIDSQL); err != nil { + return + } if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { return } @@ -70,8 +86,8 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { return } -func (s *membershipStatements) insertMembership(localpart string, roomID string) (err error) { - _, err = s.insertMembershipStmt.Exec(localpart, roomID) +func (s *membershipStatements) insertMembership(localpart string, roomID string, eventID string) (err error) { + _, err = s.insertMembershipStmt.Exec(localpart, roomID, eventID) return } @@ -79,3 +95,16 @@ func (s *membershipStatements) deleteMembership(localpart string, roomID string) _, err = s.deleteMembershipStmt.Exec(localpart, roomID) return } + +func (s *membershipStatements) deleteMembershipByEventID(eventID string) (err error) { + _, err = s.deleteMembershipByEventIDStmt.Exec(eventID) + return +} + +func (s *membershipStatements) selectMembershipByEventID(eventID string) (localpart string, roomID string, err error) { + err = s.selectMembershipByEventIDStmt.QueryRow(eventID).Scan(&localpart, &roomID) + if err == sql.ErrNoRows { + return "", "", nil + } + return +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go index c6639f884..a8631b201 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go @@ -115,17 +115,26 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6 } // SaveMembership saves the user matching a given localpart as a member of a given -// room. If a membership already exists between the user and the room, or of the +// room. It also stores the ID of the `join` membership event. +// If a membership already exists between the user and the room, or of the // insert fails, returns the SQL error -func (d *Database) SaveMembership(localpart string, roomID string) error { - return d.memberships.insertMembership(localpart, roomID) +func (d *Database) SaveMembership(localpart string, roomID string, eventID string) error { + return d.memberships.insertMembership(localpart, roomID, eventID) } -// RemoveMembership removes the membership of the user mathing a given localpart -// from a given room. +// RemoveMembership removes the membership of which the `join` membership event +// ID matches with the given event ID. // If the removal fails, or if there is no membership to remove, returns an error -func (d *Database) RemoveMembership(localpart string, roomID string) error { - return d.memberships.deleteMembership(localpart, roomID) +func (d *Database) RemoveMembership(eventID string) error { + return d.memberships.deleteMembershipByEventID(eventID) +} + +// GetMembershipByEventID returns the membership (as a user localpart and a room ID) +// for which the `join` membership event ID matches a given event ID +// If no membership match this event ID, the localpart and room ID will be empty strings +// If an error happens during the retrieval, returns the SQL error +func (d *Database) GetMembershipByEventID(eventID string) (string, string, error) { + return d.memberships.selectMembershipByEventID(eventID) } func hashPassword(plaintext string) (hash string, err error) { diff --git a/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go b/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go index 6173e2077..9490ec1b4 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go +++ b/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go @@ -89,6 +89,75 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error { "type": ev.Type(), }).Info("received event from roomserver") + events, err := s.lookupStateEvents(output.NewRoomEvent.AddsStateEventIDs, ev) + if err != nil { + return err + } + + for _, event := range events { + if err := s.updateMembership(event); err != nil { + return err + } + } + + for _, id := range output.NewRoomEvent.RemovesStateEventIDs { + if err := s.db.RemoveMembership(id); err != nil { + return err + } + } + + return nil +} + +// lookupStateEvents looks up the state events that are added by a new event. +func (s *OutputRoomEvent) lookupStateEvents( + addsStateEventIDs []string, event gomatrixserverlib.Event, +) ([]gomatrixserverlib.Event, error) { + // Fast path if there aren't any new state events. + if len(addsStateEventIDs) == 0 { + return nil, nil + } + + // Fast path if the only state event added is the event itself. + if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() { + return []gomatrixserverlib.Event{event}, nil + } + + result := []gomatrixserverlib.Event{} + missing := []string{} + for _, id := range addsStateEventIDs { + // Check if the event is already known + localpart, server, err := s.db.GetMembershipByEventID(id) + if err != nil { + return nil, err + } + + // Append the ID to the list to request so if it isn't in the database + if len(localpart) == 0 && len(server) == 0 { + missing = append(missing, id) + } + + // Append the current event in the results if its ID is in the events list + if id == event.EventID() { + result = append(result, event) + } + } + + // At this point the missing events are neither the event itself nor are + // they present in our local database. Our only option is to fetch them + // from the roomserver using the query API. + eventReq := api.QueryEventsByIDRequest{EventIDs: missing} + var eventResp api.QueryEventsByIDResponse + if err := s.query.QueryEventsByID(&eventReq, &eventResp); err != nil { + return nil, err + } + + result = append(result, eventResp.Events...) + + return result, nil +} + +func (s *OutputRoomEvent) updateMembership(ev gomatrixserverlib.Event) error { if ev.Type() == "m.room.member" && ev.StateKey() != nil { localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) if err != nil { @@ -100,6 +169,7 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error { return nil } + eventID := ev.EventID() roomID := ev.RoomID() membership, err := ev.Membership() if err != nil { @@ -108,16 +178,15 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error { switch membership { case "join": - if err := s.db.SaveMembership(localpart, roomID); err != nil { + if err := s.db.SaveMembership(localpart, roomID, eventID); err != nil { return err } case "leave": case "ban": - if err := s.db.RemoveMembership(localpart, roomID); err != nil { + if err := s.db.RemoveMembership(eventID); err != nil { return err } } } - return nil }