diff --git a/roomserver/internal/input.go b/roomserver/internal/input.go index ab3d7516b..932b4df46 100644 --- a/roomserver/internal/input.go +++ b/roomserver/internal/input.go @@ -60,7 +60,7 @@ func (r *RoomserverInternalAPI) InputRoomEvents( defer r.mutex.Unlock() for i := range request.InputInviteEvents { var loopback *api.InputRoomEvent - if loopback, err = processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil { + if loopback, err = r.processInviteEvent(ctx, r, request.InputInviteEvents[i]); err != nil { return err } // The processInviteEvent function can optionally return a @@ -71,7 +71,7 @@ func (r *RoomserverInternalAPI) InputRoomEvents( } } for i := range request.InputRoomEvents { - if response.EventID, err = processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil { + if response.EventID, err = r.processRoomEvent(ctx, request.InputRoomEvents[i]); err != nil { return err } } diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input_events.go index f5c678ca6..a4167714d 100644 --- a/roomserver/internal/input_events.go +++ b/roomserver/internal/input_events.go @@ -31,21 +31,13 @@ import ( log "github.com/sirupsen/logrus" ) -// OutputRoomEventWriter has the APIs needed to write an event to the output logs. -type OutputRoomEventWriter interface { - // Write a list of events for a room - WriteOutputEvents(roomID string, updates []api.OutputEvent) error -} - // processRoomEvent can only be called once at a time // // TODO(#375): This should be rewritten to allow concurrent calls. The // difficulty is in ensuring that we correctly annotate events with the correct // state deltas when sending to kafka streams -func processRoomEvent( +func (r *RoomserverInternalAPI) processRoomEvent( ctx context.Context, - db storage.Database, - ow OutputRoomEventWriter, input api.InputRoomEvent, ) (eventID string, err error) { // Parse and validate the event JSON @@ -54,7 +46,7 @@ func processRoomEvent( // Check that the event passes authentication checks and work out // the numeric IDs for the auth events. - authEventNIDs, err := checkAuthEvents(ctx, db, headered, input.AuthEventIDs) + authEventNIDs, err := checkAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) if err != nil { logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event") return @@ -63,7 +55,7 @@ func processRoomEvent( // If we don't have a transaction ID then get one. if input.TransactionID != nil { tdID := input.TransactionID - eventID, err = db.GetTransactionEventID( + eventID, err = r.DB.GetTransactionEventID( ctx, tdID.TransactionID, tdID.SessionID, event.Sender(), ) // On error OR event with the transaction already processed/processesing @@ -73,7 +65,7 @@ func processRoomEvent( } // Store the event. - roomNID, stateAtEvent, err := db.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) + roomNID, stateAtEvent, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) if err != nil { return } @@ -93,16 +85,14 @@ func processRoomEvent( if stateAtEvent.BeforeStateSnapshotNID == 0 { // We haven't calculated a state for this event yet. // Lets calculate one. - err = calculateAndSetState(ctx, db, input, roomNID, &stateAtEvent, event) + err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event) if err != nil { return } } - if err = updateLatestEvents( + if err = r.updateLatestEvents( ctx, // context - db, // roomserver database - ow, // output event writer roomNID, // room NID to update stateAtEvent, // state at event (below) event, // event @@ -116,29 +106,36 @@ func processRoomEvent( return event.EventID(), nil } -func calculateAndSetState( +func (r *RoomserverInternalAPI) calculateAndSetState( ctx context.Context, - db storage.Database, input api.InputRoomEvent, roomNID types.RoomNID, stateAtEvent *types.StateAtEvent, event gomatrixserverlib.Event, ) error { var err error - roomState := state.NewStateResolution(db) + roomState := state.NewStateResolution(r.DB) if input.HasState { - // TODO: Check here if we think we're in the room already. + // Check here if we think we're in the room already. stateAtEvent.Overwrite = true + var joinEventNIDs []types.EventNID + // Request join memberships only for local users only. + if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, true); err == nil { + // If we have no local users that are joined to the room then any state about + // the room that we have is quite possibly out of date. Therefore in that case + // we should overwrite it rather than merge it. + stateAtEvent.Overwrite = len(joinEventNIDs) == 0 + } // We've been told what the state at the event is so we don't need to calculate it. // Check that those state events are in the database and store the state. var entries []types.StateEntry - if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { + if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { return err } - if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { return err } } else { @@ -149,12 +146,11 @@ func calculateAndSetState( return err } } - return db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) + return r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) } -func processInviteEvent( +func (r *RoomserverInternalAPI) processInviteEvent( ctx context.Context, - db storage.Database, ow *RoomserverInternalAPI, input api.InputInviteEvent, ) (*api.InputRoomEvent, error) { @@ -172,7 +168,10 @@ func processInviteEvent( "target_user_id": targetUserID, }).Info("processing invite event") - updater, err := db.MembershipUpdater(ctx, roomID, targetUserID, input.RoomVersion) + _, domain, _ := gomatrixserverlib.SplitID('@', targetUserID) + isTargetLocalUser := domain == r.Cfg.Matrix.ServerName + + updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocalUser, input.RoomVersion) if err != nil { return nil, err } @@ -239,7 +238,7 @@ func processInviteEvent( // up from local data (which is most likely to be if the event came // from the CS API). If we know about the room then we can insert // the invite room state, if we don't then we just fail quietly. - if irs, ierr := buildInviteStrippedState(ctx, db, input); ierr == nil { + if irs, ierr := buildInviteStrippedState(ctx, r.DB, input); ierr == nil { if err = event.SetUnsignedField("invite_room_state", irs); err != nil { return nil, err } diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input_latest_events.go index 6eeeedab0..d7c9a5cb6 100644 --- a/roomserver/internal/input_latest_events.go +++ b/roomserver/internal/input_latest_events.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -46,17 +45,15 @@ import ( // 7 <----- latest // // Can only be called once at a time -func updateLatestEvents( +func (r *RoomserverInternalAPI) updateLatestEvents( ctx context.Context, - db storage.Database, - ow OutputRoomEventWriter, roomNID types.RoomNID, stateAtEvent types.StateAtEvent, event gomatrixserverlib.Event, sendAsServer string, transactionID *api.TransactionID, ) (err error) { - updater, err := db.GetLatestEventsForUpdate(ctx, roomNID) + updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID) if err != nil { return } @@ -70,9 +67,8 @@ func updateLatestEvents( u := latestEventsUpdater{ ctx: ctx, - db: db, + api: r, updater: updater, - ow: ow, roomNID: roomNID, stateAtEvent: stateAtEvent, event: event, @@ -94,9 +90,8 @@ func updateLatestEvents( // when there are so many variables to pass around. type latestEventsUpdater struct { ctx context.Context - db storage.Database + api *RoomserverInternalAPI updater types.RoomRecentEventsUpdater - ow OutputRoomEventWriter roomNID types.RoomNID stateAtEvent types.StateAtEvent event gomatrixserverlib.Event @@ -181,7 +176,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // If we need to generate any output events then here's where we do it. // TODO: Move this! - updates, err := updateMemberships(u.ctx, u.db, u.updater, u.removed, u.added) + updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added) if err != nil { return err } @@ -200,7 +195,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // send the event asynchronously but we would need to ensure that 1) the events are written to the log in // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the // necessary bookkeeping we'll keep the event sending synchronous for now. - if err = u.ow.WriteOutputEvents(u.event.RoomID(), updates); err != nil { + if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil { return err } @@ -213,7 +208,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - roomState := state.NewStateResolution(u.db) + roomState := state.NewStateResolution(u.api.DB) // Get a list of the current latest events. latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)) @@ -303,7 +298,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) latestEventIDs[i] = u.latest[i].EventID } - roomVersion, err := u.db.GetRoomVersionForRoom(u.ctx, u.event.RoomID()) + roomVersion, err := u.api.DB.GetRoomVersionForRoom(u.ctx, u.event.RoomID()) if err != nil { return nil, err } @@ -329,7 +324,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) stateEventNIDs = append(stateEventNIDs, entry.EventNID) } stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] - eventIDMap, err := u.db.EventIDs(u.ctx, stateEventNIDs) + eventIDMap, err := u.api.DB.EventIDs(u.ctx, stateEventNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/input_membership.go b/roomserver/internal/input_membership.go index 666e7ebcc..af0c7f8b3 100644 --- a/roomserver/internal/input_membership.go +++ b/roomserver/internal/input_membership.go @@ -19,7 +19,6 @@ import ( "fmt" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -28,9 +27,8 @@ import ( // user affected by a change in the current state of the room. // Returns a list of output events to write to the kafka log to inform the // consumers about the invites added or retired by the change in current state. -func updateMemberships( +func (r *RoomserverInternalAPI) updateMemberships( ctx context.Context, - db storage.Database, updater types.RoomRecentEventsUpdater, removed, added []types.StateEntry, ) ([]api.OutputEvent, error) { @@ -48,7 +46,7 @@ func updateMemberships( // Load the event JSON so we can look up the "membership" key. // TODO: Maybe add a membership key to the events table so we can load that // key without having to load the entire event JSON? - events, err := db.Events(ctx, eventNIDs) + events, err := r.DB.Events(ctx, eventNIDs) if err != nil { return nil, err } @@ -71,15 +69,16 @@ func updateMemberships( ae = &ev.Event } } - if updates, err = updateMembership(updater, targetUserNID, re, ae, updates); err != nil { + if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil { return nil, err } } return updates, nil } -func updateMembership( - updater types.RoomRecentEventsUpdater, targetUserNID types.EventStateKeyNID, +func (r *RoomserverInternalAPI) updateMembership( + updater types.RoomRecentEventsUpdater, + targetUserNID types.EventStateKeyNID, remove, add *gomatrixserverlib.Event, updates []api.OutputEvent, ) ([]api.OutputEvent, error) { @@ -113,7 +112,7 @@ func updateMembership( return updates, nil } - mu, err := updater.MembershipUpdater(targetUserNID) + mu, err := updater.MembershipUpdater(targetUserNID, r.isLocalTarget(add)) if err != nil { return nil, err } @@ -132,6 +131,15 @@ func updateMembership( } } +func (r *RoomserverInternalAPI) isLocalTarget(event *gomatrixserverlib.Event) bool { + isTargetLocalUser := false + if statekey := event.StateKey(); statekey != nil { + _, domain, _ := gomatrixserverlib.SplitID('@', *statekey) + isTargetLocalUser = domain == r.Cfg.Matrix.ServerName + } + return isTargetLocalUser +} + func updateToInviteMembership( mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, roomVersion gomatrixserverlib.RoomVersion, diff --git a/roomserver/internal/query.go b/roomserver/internal/query.go index 2d1c21c57..fce2ae907 100644 --- a/roomserver/internal/query.go +++ b/roomserver/internal/query.go @@ -267,7 +267,7 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom( var stateEntries []types.StateEntry if stillInRoom { var eventNIDs []types.EventNID - eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly) + eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly, false) if err != nil { return err } @@ -591,7 +591,7 @@ func (r *RoomserverInternalAPI) isServerCurrentlyInRoom(ctx context.Context, ser return false, err } - eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true) + eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, false) if err != nil { return false, err } diff --git a/roomserver/internal/query_backfill.go b/roomserver/internal/query_backfill.go index 49e0af34a..23ae9455a 100644 --- a/roomserver/internal/query_backfill.go +++ b/roomserver/internal/query_backfill.go @@ -297,7 +297,7 @@ func joinEventsFromHistoryVisibility( if err != nil { return nil, err } - joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true) + joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true, false) if err != nil { return nil, err } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index fb39eca63..1e0232d20 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -83,9 +83,9 @@ type Database interface { GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) GetCreatorIDForAlias(ctx context.Context, alias string) (string, error) RemoveRoomAlias(ctx context.Context, alias string) error - MembershipUpdater(ctx context.Context, roomID, targetUserID string, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error) + MembershipUpdater(ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) - GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) + GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) } diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 9c8a4c259..820ef4e71 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -59,6 +59,10 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( -- This NID is updated if the join event gets updated (e.g. profile update), -- or if the user leaves/joins the room. event_nid BIGINT NOT NULL DEFAULT 0, + -- Local target is true if the target_nid refers to a local user rather than + -- a federated one. This is an optimisation for resetting state on federated + -- room joins. + target_local BOOLEAN NOT NULL DEFAULT false, UNIQUE (room_nid, target_nid) ); ` @@ -66,8 +70,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + - "INSERT INTO roomserver_membership (room_nid, target_nid)" + - " VALUES ($1, $2)" + + "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" + + " VALUES ($1, $2, $3)" + " ON CONFLICT DO NOTHING" const selectMembershipFromRoomAndTargetSQL = "" + @@ -78,10 +82,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND membership_nid = $2" +const selectLocalMembershipsFromRoomAndMembershipSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND membership_nid = $2" + + " AND target_local = true" + const selectMembershipsFromRoomSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1" +const selectLocalMembershipsFromRoomSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1" + + " AND target_local = true" + const selectMembershipForUpdateSQL = "" + "SELECT membership_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE" @@ -91,12 +105,14 @@ const updateMembershipSQL = "" + " WHERE room_nid = $1 AND target_nid = $2" type membershipStatements struct { - insertMembershipStmt *sql.Stmt - selectMembershipForUpdateStmt *sql.Stmt - selectMembershipFromRoomAndTargetStmt *sql.Stmt - selectMembershipsFromRoomAndMembershipStmt *sql.Stmt - selectMembershipsFromRoomStmt *sql.Stmt - updateMembershipStmt *sql.Stmt + insertMembershipStmt *sql.Stmt + selectMembershipForUpdateStmt *sql.Stmt + selectMembershipFromRoomAndTargetStmt *sql.Stmt + selectMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectMembershipsFromRoomStmt *sql.Stmt + selectLocalMembershipsFromRoomStmt *sql.Stmt + updateMembershipStmt *sql.Stmt } func (s *membershipStatements) prepare(db *sql.DB) (err error) { @@ -110,7 +126,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, + {&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL}, {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, + {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, }.prepare(db) } @@ -118,9 +136,10 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) insertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + localTarget bool, ) error { stmt := common.TxStmt(txn, s.insertMembershipStmt) - _, err := stmt.ExecContext(ctx, roomNID, targetUserNID) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) return err } @@ -145,9 +164,15 @@ func (s *membershipStatements) selectMembershipFromRoomAndTarget( } func (s *membershipStatements) selectMembershipsFromRoom( - ctx context.Context, roomNID types.RoomNID, + ctx context.Context, roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { - rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID) + var stmt *sql.Stmt + if localOnly { + stmt = s.selectLocalMembershipsFromRoomStmt + } else { + stmt = s.selectMembershipsFromRoomStmt + } + rows, err := stmt.QueryContext(ctx, roomNID) if err != nil { return } @@ -165,10 +190,16 @@ func (s *membershipStatements) selectMembershipsFromRoom( func (s *membershipStatements) selectMembershipsFromRoomAndMembership( ctx context.Context, - roomNID types.RoomNID, membership membershipState, + roomNID types.RoomNID, membership membershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { - stmt := s.selectMembershipsFromRoomAndMembershipStmt - rows, err := stmt.QueryContext(ctx, roomNID, membership) + var rows *sql.Rows + var stmt *sql.Stmt + if localOnly { + stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt + } else { + stmt = s.selectMembershipsFromRoomAndMembershipStmt + } + rows, err = stmt.QueryContext(ctx, roomNID, membership) if err != nil { return } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 1d825ecc2..d451d6650 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -459,8 +459,8 @@ func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID) } -func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) { - return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID) +func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (types.MembershipUpdater, error) { + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) } // RoomNID implements query.RoomserverQueryAPIDB @@ -558,7 +558,7 @@ func (d *Database) StateEntriesForTuples( // MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, - roomVersion gomatrixserverlib.RoomVersion, + targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, ) (types.MembershipUpdater, error) { txn, err := d.db.Begin() if err != nil { @@ -581,7 +581,7 @@ func (d *Database) MembershipUpdater( return nil, err } - updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) + updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal) if err != nil { return nil, err } @@ -603,9 +603,10 @@ func (d *Database) membershipUpdaterTxn( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + targetLocal bool, ) (types.MembershipUpdater, error) { - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { + if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { return nil, err } @@ -748,15 +749,15 @@ func (d *Database) GetMembership( // GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, + ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, ) ([]types.EventNID, error) { if joinOnly { return d.statements.selectMembershipsFromRoomAndMembership( - ctx, roomNID, membershipStateJoin, + ctx, roomNID, membershipStateJoin, localOnly, ) } - return d.statements.selectMembershipsFromRoom(ctx, roomNID) + return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly) } // EventsFromIDs implements query.RoomserverQueryAPIEventDB diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 7ae28e4b8..ca4d8fbe9 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -38,6 +38,7 @@ const membershipSchema = ` sender_nid INTEGER NOT NULL DEFAULT 0, membership_nid INTEGER NOT NULL DEFAULT 1, event_nid INTEGER NOT NULL DEFAULT 0, + target_local BOOLEAN NOT NULL DEFAULT false, UNIQUE (room_nid, target_nid) ); ` @@ -45,8 +46,8 @@ const membershipSchema = ` // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + - "INSERT INTO roomserver_membership (room_nid, target_nid)" + - " VALUES ($1, $2)" + + "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" + + " VALUES ($1, $2, $3)" + " ON CONFLICT DO NOTHING" const selectMembershipFromRoomAndTargetSQL = "" + @@ -57,10 +58,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND membership_nid = $2" +const selectLocalMembershipsFromRoomAndMembershipSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND membership_nid = $2" + + " AND target_local = true" + const selectMembershipsFromRoomSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1" +const selectLocalMembershipsFromRoomSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1" + + " AND target_local = true" + const selectMembershipForUpdateSQL = "" + "SELECT membership_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2" @@ -70,12 +81,14 @@ const updateMembershipSQL = "" + " WHERE room_nid = $4 AND target_nid = $5" type membershipStatements struct { - insertMembershipStmt *sql.Stmt - selectMembershipForUpdateStmt *sql.Stmt - selectMembershipFromRoomAndTargetStmt *sql.Stmt - selectMembershipsFromRoomAndMembershipStmt *sql.Stmt - selectMembershipsFromRoomStmt *sql.Stmt - updateMembershipStmt *sql.Stmt + insertMembershipStmt *sql.Stmt + selectMembershipForUpdateStmt *sql.Stmt + selectMembershipFromRoomAndTargetStmt *sql.Stmt + selectMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectMembershipsFromRoomStmt *sql.Stmt + selectLocalMembershipsFromRoomStmt *sql.Stmt + updateMembershipStmt *sql.Stmt } func (s *membershipStatements) prepare(db *sql.DB) (err error) { @@ -89,7 +102,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, + {&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL}, {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, + {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, }.prepare(db) } @@ -97,9 +112,10 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) insertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + localTarget bool, ) error { stmt := common.TxStmt(txn, s.insertMembershipStmt) - _, err := stmt.ExecContext(ctx, roomNID, targetUserNID) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) return err } @@ -127,9 +143,14 @@ func (s *membershipStatements) selectMembershipFromRoomAndTarget( func (s *membershipStatements) selectMembershipsFromRoom( ctx context.Context, txn *sql.Tx, - roomNID types.RoomNID, + roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { - selectStmt := common.TxStmt(txn, s.selectMembershipsFromRoomStmt) + var selectStmt *sql.Stmt + if localOnly { + selectStmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomStmt) + } else { + selectStmt = common.TxStmt(txn, s.selectMembershipsFromRoomStmt) + } rows, err := selectStmt.QueryContext(ctx, roomNID) if err != nil { return nil, err @@ -145,11 +166,17 @@ func (s *membershipStatements) selectMembershipsFromRoom( } return } + func (s *membershipStatements) selectMembershipsFromRoomAndMembership( ctx context.Context, txn *sql.Tx, - roomNID types.RoomNID, membership membershipState, + roomNID types.RoomNID, membership membershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { - stmt := common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt) + var stmt *sql.Stmt + if localOnly { + stmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomAndMembershipStmt) + } else { + stmt = common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt) + } rows, err := stmt.QueryContext(ctx, roomNID, membership) if err != nil { return diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index e77fea9cf..209922fa2 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -569,9 +569,9 @@ func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error return err } -func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) { +func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (mu types.MembershipUpdater, err error) { err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { - mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID) + mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID, targetLocal) return err }) return @@ -680,7 +680,7 @@ func (d *Database) StateEntriesForTuples( // MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, - roomVersion gomatrixserverlib.RoomVersion, + targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, ) (updater types.MembershipUpdater, err error) { var txn *sql.Tx txn, err = d.db.Begin() @@ -716,7 +716,7 @@ func (d *Database) MembershipUpdater( return nil, err } - updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) + updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal) if err != nil { return nil, err } @@ -738,9 +738,10 @@ func (d *Database) membershipUpdaterTxn( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + targetLocal bool, ) (types.MembershipUpdater, error) { - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { + if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { return nil, err } @@ -896,17 +897,17 @@ func (d *Database) GetMembership( // GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, + ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, ) (eventNIDs []types.EventNID, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { if joinOnly { eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership( - ctx, txn, roomNID, membershipStateJoin, + ctx, txn, roomNID, membershipStateJoin, localOnly, ) return nil } - eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID) + eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID, localOnly) return nil }) return diff --git a/roomserver/types/types.go b/roomserver/types/types.go index da83f614c..74e6b0784 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -172,7 +172,7 @@ type RoomRecentEventsUpdater interface { MarkEventAsSent(eventNID EventNID) error // Build a membership updater for the target user in this room. // It will share the same transaction as this updater. - MembershipUpdater(targetUserNID EventStateKeyNID) (MembershipUpdater, error) + MembershipUpdater(targetUserNID EventStateKeyNID, isTargetLocalUser bool) (MembershipUpdater, error) // Implements Transaction so it can be committed or rolledback common.Transaction }