diff --git a/roomserver/internal/input.go b/roomserver/internal/input.go index ab3d7516b..932b4df46 100644 --- a/roomserver/internal/input.go +++ b/roomserver/internal/input.go @@ -60,7 +60,7 @@ func (r *RoomserverInternalAPI) InputRoomEvents( defer r.mutex.Unlock() for i := range request.InputInviteEvents { var loopback *api.InputRoomEvent - if loopback, err = processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil { + if loopback, err = r.processInviteEvent(ctx, r, request.InputInviteEvents[i]); err != nil { return err } // The processInviteEvent function can optionally return a @@ -71,7 +71,7 @@ func (r *RoomserverInternalAPI) InputRoomEvents( } } for i := range request.InputRoomEvents { - if response.EventID, err = processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil { + if response.EventID, err = r.processRoomEvent(ctx, request.InputRoomEvents[i]); err != nil { return err } } diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input_events.go index f5c678ca6..b6ce655ae 100644 --- a/roomserver/internal/input_events.go +++ b/roomserver/internal/input_events.go @@ -31,21 +31,13 @@ import ( log "github.com/sirupsen/logrus" ) -// OutputRoomEventWriter has the APIs needed to write an event to the output logs. -type OutputRoomEventWriter interface { - // Write a list of events for a room - WriteOutputEvents(roomID string, updates []api.OutputEvent) error -} - // processRoomEvent can only be called once at a time // // TODO(#375): This should be rewritten to allow concurrent calls. The // difficulty is in ensuring that we correctly annotate events with the correct // state deltas when sending to kafka streams -func processRoomEvent( +func (r *RoomserverInternalAPI) processRoomEvent( ctx context.Context, - db storage.Database, - ow OutputRoomEventWriter, input api.InputRoomEvent, ) (eventID string, err error) { // Parse and validate the event JSON @@ -54,7 +46,7 @@ func processRoomEvent( // Check that the event passes authentication checks and work out // the numeric IDs for the auth events. - authEventNIDs, err := checkAuthEvents(ctx, db, headered, input.AuthEventIDs) + authEventNIDs, err := checkAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) if err != nil { logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event") return @@ -63,7 +55,7 @@ func processRoomEvent( // If we don't have a transaction ID then get one. if input.TransactionID != nil { tdID := input.TransactionID - eventID, err = db.GetTransactionEventID( + eventID, err = r.DB.GetTransactionEventID( ctx, tdID.TransactionID, tdID.SessionID, event.Sender(), ) // On error OR event with the transaction already processed/processesing @@ -73,7 +65,7 @@ func processRoomEvent( } // Store the event. - roomNID, stateAtEvent, err := db.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) + roomNID, stateAtEvent, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) if err != nil { return } @@ -93,16 +85,14 @@ func processRoomEvent( if stateAtEvent.BeforeStateSnapshotNID == 0 { // We haven't calculated a state for this event yet. // Lets calculate one. - err = calculateAndSetState(ctx, db, input, roomNID, &stateAtEvent, event) + err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event) if err != nil { return } } - if err = updateLatestEvents( + if err = r.updateLatestEvents( ctx, // context - db, // roomserver database - ow, // output event writer roomNID, // room NID to update stateAtEvent, // state at event (below) event, // event @@ -116,16 +106,15 @@ func processRoomEvent( return event.EventID(), nil } -func calculateAndSetState( +func (r *RoomserverInternalAPI) calculateAndSetState( ctx context.Context, - db storage.Database, input api.InputRoomEvent, roomNID types.RoomNID, stateAtEvent *types.StateAtEvent, event gomatrixserverlib.Event, ) error { var err error - roomState := state.NewStateResolution(db) + roomState := state.NewStateResolution(r.DB) if input.HasState { // TODO: Check here if we think we're in the room already. @@ -134,11 +123,11 @@ func calculateAndSetState( // We've been told what the state at the event is so we don't need to calculate it. // Check that those state events are in the database and store the state. var entries []types.StateEntry - if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { + if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { return err } - if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { return err } } else { @@ -149,12 +138,11 @@ func calculateAndSetState( return err } } - return db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) + return r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) } -func processInviteEvent( +func (r *RoomserverInternalAPI) processInviteEvent( ctx context.Context, - db storage.Database, ow *RoomserverInternalAPI, input api.InputInviteEvent, ) (*api.InputRoomEvent, error) { @@ -172,7 +160,10 @@ func processInviteEvent( "target_user_id": targetUserID, }).Info("processing invite event") - updater, err := db.MembershipUpdater(ctx, roomID, targetUserID, input.RoomVersion) + _, domain, _ := gomatrixserverlib.SplitID('@', targetUserID) + targetLocal := domain == r.Cfg.Matrix.ServerName + + updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, targetLocal, input.RoomVersion) if err != nil { return nil, err } @@ -239,7 +230,7 @@ func processInviteEvent( // up from local data (which is most likely to be if the event came // from the CS API). If we know about the room then we can insert // the invite room state, if we don't then we just fail quietly. - if irs, ierr := buildInviteStrippedState(ctx, db, input); ierr == nil { + if irs, ierr := buildInviteStrippedState(ctx, r.DB, input); ierr == nil { if err = event.SetUnsignedField("invite_room_state", irs); err != nil { return nil, err } diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input_latest_events.go index 6eeeedab0..d7c9a5cb6 100644 --- a/roomserver/internal/input_latest_events.go +++ b/roomserver/internal/input_latest_events.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -46,17 +45,15 @@ import ( // 7 <----- latest // // Can only be called once at a time -func updateLatestEvents( +func (r *RoomserverInternalAPI) updateLatestEvents( ctx context.Context, - db storage.Database, - ow OutputRoomEventWriter, roomNID types.RoomNID, stateAtEvent types.StateAtEvent, event gomatrixserverlib.Event, sendAsServer string, transactionID *api.TransactionID, ) (err error) { - updater, err := db.GetLatestEventsForUpdate(ctx, roomNID) + updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID) if err != nil { return } @@ -70,9 +67,8 @@ func updateLatestEvents( u := latestEventsUpdater{ ctx: ctx, - db: db, + api: r, updater: updater, - ow: ow, roomNID: roomNID, stateAtEvent: stateAtEvent, event: event, @@ -94,9 +90,8 @@ func updateLatestEvents( // when there are so many variables to pass around. type latestEventsUpdater struct { ctx context.Context - db storage.Database + api *RoomserverInternalAPI updater types.RoomRecentEventsUpdater - ow OutputRoomEventWriter roomNID types.RoomNID stateAtEvent types.StateAtEvent event gomatrixserverlib.Event @@ -181,7 +176,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // If we need to generate any output events then here's where we do it. // TODO: Move this! - updates, err := updateMemberships(u.ctx, u.db, u.updater, u.removed, u.added) + updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added) if err != nil { return err } @@ -200,7 +195,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // send the event asynchronously but we would need to ensure that 1) the events are written to the log in // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the // necessary bookkeeping we'll keep the event sending synchronous for now. - if err = u.ow.WriteOutputEvents(u.event.RoomID(), updates); err != nil { + if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil { return err } @@ -213,7 +208,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - roomState := state.NewStateResolution(u.db) + roomState := state.NewStateResolution(u.api.DB) // Get a list of the current latest events. latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)) @@ -303,7 +298,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) latestEventIDs[i] = u.latest[i].EventID } - roomVersion, err := u.db.GetRoomVersionForRoom(u.ctx, u.event.RoomID()) + roomVersion, err := u.api.DB.GetRoomVersionForRoom(u.ctx, u.event.RoomID()) if err != nil { return nil, err } @@ -329,7 +324,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) stateEventNIDs = append(stateEventNIDs, entry.EventNID) } stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] - eventIDMap, err := u.db.EventIDs(u.ctx, stateEventNIDs) + eventIDMap, err := u.api.DB.EventIDs(u.ctx, stateEventNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/input_membership.go b/roomserver/internal/input_membership.go index 666e7ebcc..2383ec567 100644 --- a/roomserver/internal/input_membership.go +++ b/roomserver/internal/input_membership.go @@ -19,7 +19,6 @@ import ( "fmt" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -28,9 +27,8 @@ import ( // user affected by a change in the current state of the room. // Returns a list of output events to write to the kafka log to inform the // consumers about the invites added or retired by the change in current state. -func updateMemberships( +func (r *RoomserverInternalAPI) updateMemberships( ctx context.Context, - db storage.Database, updater types.RoomRecentEventsUpdater, removed, added []types.StateEntry, ) ([]api.OutputEvent, error) { @@ -48,7 +46,7 @@ func updateMemberships( // Load the event JSON so we can look up the "membership" key. // TODO: Maybe add a membership key to the events table so we can load that // key without having to load the entire event JSON? - events, err := db.Events(ctx, eventNIDs) + events, err := r.DB.Events(ctx, eventNIDs) if err != nil { return nil, err } @@ -59,6 +57,7 @@ func updateMemberships( var ae *gomatrixserverlib.Event var re *gomatrixserverlib.Event targetUserNID := change.EventStateKeyNID + targetLocal := false if change.removedEventNID != 0 { ev, _ := eventMap(events).lookup(change.removedEventNID) if ev != nil { @@ -71,15 +70,17 @@ func updateMemberships( ae = &ev.Event } } - if updates, err = updateMembership(updater, targetUserNID, re, ae, updates); err != nil { + if updates, err = r.updateMembership(updater, targetUserNID, targetLocal, re, ae, updates); err != nil { return nil, err } } return updates, nil } -func updateMembership( - updater types.RoomRecentEventsUpdater, targetUserNID types.EventStateKeyNID, +func (r *RoomserverInternalAPI) updateMembership( + updater types.RoomRecentEventsUpdater, + targetUserNID types.EventStateKeyNID, + targetLocal bool, remove, add *gomatrixserverlib.Event, updates []api.OutputEvent, ) ([]api.OutputEvent, error) { @@ -113,7 +114,13 @@ func updateMembership( return updates, nil } - mu, err := updater.MembershipUpdater(targetUserNID) + targetLocal = false + if statekey := add.StateKey(); statekey != nil { + _, domain, _ := gomatrixserverlib.SplitID('@', *statekey) + targetLocal = domain == r.Cfg.Matrix.ServerName + } + + mu, err := updater.MembershipUpdater(targetUserNID, targetLocal) if err != nil { return nil, err } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index fb39eca63..f14b64666 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -83,7 +83,7 @@ type Database interface { GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) GetCreatorIDForAlias(ctx context.Context, alias string) (string, error) RemoveRoomAlias(ctx context.Context, alias string) error - MembershipUpdater(ctx context.Context, roomID, targetUserID string, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error) + MembershipUpdater(ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 9c8a4c259..f8bff18d5 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -59,6 +59,10 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( -- This NID is updated if the join event gets updated (e.g. profile update), -- or if the user leaves/joins the room. event_nid BIGINT NOT NULL DEFAULT 0, + -- Local target is true if the target_nid refers to a local user rather than + -- a federated one. This is an optimisation for resetting state on federated + -- room joins. + local_target BOOLEAN NOT NULL DEFAULT false, UNIQUE (room_nid, target_nid) ); ` @@ -66,8 +70,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + - "INSERT INTO roomserver_membership (room_nid, target_nid)" + - " VALUES ($1, $2)" + + "INSERT INTO roomserver_membership (room_nid, target_nid, local_target)" + + " VALUES ($1, $2, $3)" + " ON CONFLICT DO NOTHING" const selectMembershipFromRoomAndTargetSQL = "" + @@ -118,9 +122,10 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) insertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + localTarget bool, ) error { stmt := common.TxStmt(txn, s.insertMembershipStmt) - _, err := stmt.ExecContext(ctx, roomNID, targetUserNID) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) return err } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 1d825ecc2..ea44e3f63 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -459,8 +459,8 @@ func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID) } -func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) { - return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID) +func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (types.MembershipUpdater, error) { + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) } // RoomNID implements query.RoomserverQueryAPIDB @@ -558,7 +558,7 @@ func (d *Database) StateEntriesForTuples( // MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, - roomVersion gomatrixserverlib.RoomVersion, + targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, ) (types.MembershipUpdater, error) { txn, err := d.db.Begin() if err != nil { @@ -581,7 +581,7 @@ func (d *Database) MembershipUpdater( return nil, err } - updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) + updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal) if err != nil { return nil, err } @@ -603,9 +603,10 @@ func (d *Database) membershipUpdaterTxn( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + targetLocal bool, ) (types.MembershipUpdater, error) { - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { + if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 7ae28e4b8..0e3201aa4 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -38,6 +38,7 @@ const membershipSchema = ` sender_nid INTEGER NOT NULL DEFAULT 0, membership_nid INTEGER NOT NULL DEFAULT 1, event_nid INTEGER NOT NULL DEFAULT 0, + local_target BOOLEAN NOT NULL DEFAULT false, UNIQUE (room_nid, target_nid) ); ` @@ -45,8 +46,8 @@ const membershipSchema = ` // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + - "INSERT INTO roomserver_membership (room_nid, target_nid)" + - " VALUES ($1, $2)" + + "INSERT INTO roomserver_membership (room_nid, target_nid, local_target)" + + " VALUES ($1, $2, $3)" + " ON CONFLICT DO NOTHING" const selectMembershipFromRoomAndTargetSQL = "" + @@ -97,9 +98,10 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) insertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + localTarget bool, ) error { stmt := common.TxStmt(txn, s.insertMembershipStmt) - _, err := stmt.ExecContext(ctx, roomNID, targetUserNID) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) return err } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index e77fea9cf..c00702a05 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -569,9 +569,9 @@ func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error return err } -func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) { +func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (mu types.MembershipUpdater, err error) { err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { - mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID) + mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID, targetLocal) return err }) return @@ -680,7 +680,7 @@ func (d *Database) StateEntriesForTuples( // MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, - roomVersion gomatrixserverlib.RoomVersion, + targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, ) (updater types.MembershipUpdater, err error) { var txn *sql.Tx txn, err = d.db.Begin() @@ -716,7 +716,7 @@ func (d *Database) MembershipUpdater( return nil, err } - updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) + updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal) if err != nil { return nil, err } @@ -738,9 +738,10 @@ func (d *Database) membershipUpdaterTxn( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + targetLocal bool, ) (types.MembershipUpdater, error) { - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { + if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { return nil, err } diff --git a/roomserver/types/types.go b/roomserver/types/types.go index da83f614c..477664f87 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -172,7 +172,7 @@ type RoomRecentEventsUpdater interface { MarkEventAsSent(eventNID EventNID) error // Build a membership updater for the target user in this room. // It will share the same transaction as this updater. - MembershipUpdater(targetUserNID EventStateKeyNID) (MembershipUpdater, error) + MembershipUpdater(targetUserNID EventStateKeyNID, targetLocal bool) (MembershipUpdater, error) // Implements Transaction so it can be committed or rolledback common.Transaction }