From c5e3ac95375951198ef1a3978c128b54774a6cce Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 19 May 2020 13:27:45 +0100 Subject: [PATCH] Check if we're in the room already before resetting latest events/state --- roomserver/internal/input_events.go | 6 ++- roomserver/internal/query.go | 4 +- roomserver/internal/query_backfill.go | 2 +- roomserver/storage/interface.go | 2 +- .../storage/postgres/membership_table.go | 40 ++++++++++----- roomserver/storage/postgres/storage.go | 6 +-- .../storage/sqlite3/membership_table.go | 49 ++++++++++++++----- roomserver/storage/sqlite3/storage.go | 6 +-- 8 files changed, 81 insertions(+), 34 deletions(-) diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input_events.go index b6ce655ae..864ecf3fa 100644 --- a/roomserver/internal/input_events.go +++ b/roomserver/internal/input_events.go @@ -117,8 +117,12 @@ func (r *RoomserverInternalAPI) calculateAndSetState( roomState := state.NewStateResolution(r.DB) if input.HasState { - // TODO: Check here if we think we're in the room already. + // Check here if we think we're in the room already. stateAtEvent.Overwrite = true + var joinEventNIDs []types.EventNID + if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, true); err == nil { + stateAtEvent.Overwrite = len(joinEventNIDs) == 0 + } // We've been told what the state at the event is so we don't need to calculate it. // Check that those state events are in the database and store the state. diff --git a/roomserver/internal/query.go b/roomserver/internal/query.go index 8a8c8e7d0..d4065d6ba 100644 --- a/roomserver/internal/query.go +++ b/roomserver/internal/query.go @@ -267,7 +267,7 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom( var stateEntries []types.StateEntry if stillInRoom { var eventNIDs []types.EventNID - eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly) + eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly, false) if err != nil { return err } @@ -592,7 +592,7 @@ func (r *RoomserverInternalAPI) isServerCurrentlyInRoom(ctx context.Context, ser return false, err } - eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true) + eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, false) if err != nil { return false, err } diff --git a/roomserver/internal/query_backfill.go b/roomserver/internal/query_backfill.go index d42038e74..1e9f8e88e 100644 --- a/roomserver/internal/query_backfill.go +++ b/roomserver/internal/query_backfill.go @@ -270,7 +270,7 @@ func joinEventsFromHistoryVisibility( if err != nil { return nil, err } - joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true) + joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true, false) if err != nil { return nil, err } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index f14b64666..1e0232d20 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -85,7 +85,7 @@ type Database interface { RemoveRoomAlias(ctx context.Context, alias string) error MembershipUpdater(ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) - GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) + GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) } diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index f8bff18d5..b46cd9458 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -62,7 +62,7 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( -- Local target is true if the target_nid refers to a local user rather than -- a federated one. This is an optimisation for resetting state on federated -- room joins. - local_target BOOLEAN NOT NULL DEFAULT false, + target_local BOOLEAN NOT NULL DEFAULT false, UNIQUE (room_nid, target_nid) ); ` @@ -70,7 +70,7 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + - "INSERT INTO roomserver_membership (room_nid, target_nid, local_target)" + + "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" + " VALUES ($1, $2, $3)" + " ON CONFLICT DO NOTHING" @@ -82,10 +82,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND membership_nid = $2" +const selectLocalMembershipsFromRoomAndMembershipSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND membership_nid = $2" + + " AND target_local = true" + const selectMembershipsFromRoomSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1" +const selectLocalMembershipsFromRoomSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1" + + " AND target_local = true" + const selectMembershipForUpdateSQL = "" + "SELECT membership_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE" @@ -95,12 +105,14 @@ const updateMembershipSQL = "" + " WHERE room_nid = $1 AND target_nid = $2" type membershipStatements struct { - insertMembershipStmt *sql.Stmt - selectMembershipForUpdateStmt *sql.Stmt - selectMembershipFromRoomAndTargetStmt *sql.Stmt - selectMembershipsFromRoomAndMembershipStmt *sql.Stmt - selectMembershipsFromRoomStmt *sql.Stmt - updateMembershipStmt *sql.Stmt + insertMembershipStmt *sql.Stmt + selectMembershipForUpdateStmt *sql.Stmt + selectMembershipFromRoomAndTargetStmt *sql.Stmt + selectMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectMembershipsFromRoomStmt *sql.Stmt + selectLocalMembershipsFromRoomStmt *sql.Stmt + updateMembershipStmt *sql.Stmt } func (s *membershipStatements) prepare(db *sql.DB) (err error) { @@ -114,7 +126,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, + {&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL}, {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, + {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, }.prepare(db) } @@ -150,7 +164,7 @@ func (s *membershipStatements) selectMembershipFromRoomAndTarget( } func (s *membershipStatements) selectMembershipsFromRoom( - ctx context.Context, roomNID types.RoomNID, + ctx context.Context, roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID) if err != nil { @@ -170,10 +184,14 @@ func (s *membershipStatements) selectMembershipsFromRoom( func (s *membershipStatements) selectMembershipsFromRoomAndMembership( ctx context.Context, - roomNID types.RoomNID, membership membershipState, + roomNID types.RoomNID, membership membershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { + var rows *sql.Rows stmt := s.selectMembershipsFromRoomAndMembershipStmt - rows, err := stmt.QueryContext(ctx, roomNID, membership) + if localOnly { + stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt + } + rows, err = stmt.QueryContext(ctx, roomNID, membership) if err != nil { return } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index ea44e3f63..d451d6650 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -749,15 +749,15 @@ func (d *Database) GetMembership( // GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, + ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, ) ([]types.EventNID, error) { if joinOnly { return d.statements.selectMembershipsFromRoomAndMembership( - ctx, roomNID, membershipStateJoin, + ctx, roomNID, membershipStateJoin, localOnly, ) } - return d.statements.selectMembershipsFromRoom(ctx, roomNID) + return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly) } // EventsFromIDs implements query.RoomserverQueryAPIEventDB diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 0e3201aa4..ca4d8fbe9 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -38,7 +38,7 @@ const membershipSchema = ` sender_nid INTEGER NOT NULL DEFAULT 0, membership_nid INTEGER NOT NULL DEFAULT 1, event_nid INTEGER NOT NULL DEFAULT 0, - local_target BOOLEAN NOT NULL DEFAULT false, + target_local BOOLEAN NOT NULL DEFAULT false, UNIQUE (room_nid, target_nid) ); ` @@ -46,7 +46,7 @@ const membershipSchema = ` // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + - "INSERT INTO roomserver_membership (room_nid, target_nid, local_target)" + + "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" + " VALUES ($1, $2, $3)" + " ON CONFLICT DO NOTHING" @@ -58,10 +58,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND membership_nid = $2" +const selectLocalMembershipsFromRoomAndMembershipSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND membership_nid = $2" + + " AND target_local = true" + const selectMembershipsFromRoomSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1" +const selectLocalMembershipsFromRoomSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1" + + " AND target_local = true" + const selectMembershipForUpdateSQL = "" + "SELECT membership_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2" @@ -71,12 +81,14 @@ const updateMembershipSQL = "" + " WHERE room_nid = $4 AND target_nid = $5" type membershipStatements struct { - insertMembershipStmt *sql.Stmt - selectMembershipForUpdateStmt *sql.Stmt - selectMembershipFromRoomAndTargetStmt *sql.Stmt - selectMembershipsFromRoomAndMembershipStmt *sql.Stmt - selectMembershipsFromRoomStmt *sql.Stmt - updateMembershipStmt *sql.Stmt + insertMembershipStmt *sql.Stmt + selectMembershipForUpdateStmt *sql.Stmt + selectMembershipFromRoomAndTargetStmt *sql.Stmt + selectMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectMembershipsFromRoomStmt *sql.Stmt + selectLocalMembershipsFromRoomStmt *sql.Stmt + updateMembershipStmt *sql.Stmt } func (s *membershipStatements) prepare(db *sql.DB) (err error) { @@ -90,7 +102,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, + {&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL}, {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, + {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, }.prepare(db) } @@ -129,9 +143,14 @@ func (s *membershipStatements) selectMembershipFromRoomAndTarget( func (s *membershipStatements) selectMembershipsFromRoom( ctx context.Context, txn *sql.Tx, - roomNID types.RoomNID, + roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { - selectStmt := common.TxStmt(txn, s.selectMembershipsFromRoomStmt) + var selectStmt *sql.Stmt + if localOnly { + selectStmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomStmt) + } else { + selectStmt = common.TxStmt(txn, s.selectMembershipsFromRoomStmt) + } rows, err := selectStmt.QueryContext(ctx, roomNID) if err != nil { return nil, err @@ -147,11 +166,17 @@ func (s *membershipStatements) selectMembershipsFromRoom( } return } + func (s *membershipStatements) selectMembershipsFromRoomAndMembership( ctx context.Context, txn *sql.Tx, - roomNID types.RoomNID, membership membershipState, + roomNID types.RoomNID, membership membershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { - stmt := common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt) + var stmt *sql.Stmt + if localOnly { + stmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomAndMembershipStmt) + } else { + stmt = common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt) + } rows, err := stmt.QueryContext(ctx, roomNID, membership) if err != nil { return diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index c00702a05..209922fa2 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -897,17 +897,17 @@ func (d *Database) GetMembership( // GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, + ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, ) (eventNIDs []types.EventNID, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { if joinOnly { eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership( - ctx, txn, roomNID, membershipStateJoin, + ctx, txn, roomNID, membershipStateJoin, localOnly, ) return nil } - eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID) + eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID, localOnly) return nil }) return