diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 9fafdbede..b540fbc69 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -22,6 +22,7 @@ import ( "strings" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -91,35 +92,37 @@ type currentRoomStateStatements struct { selectStateEventStmt *sql.Stmt } -func (s *currentRoomStateStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { - s.streamIDStatements = streamID - _, err = db.Exec(currentRoomStateSchema) +func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { + s := ¤tRoomStateStatements{ + streamIDStatements: streamID, + } + _, err := db.Exec(currentRoomStateSchema) if err != nil { - return + return nil, err } if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { - return + return nil, err } if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { - return + return nil, err } if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { - return + return nil, err } if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil { - return + return nil, err } if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { - return + return nil, err } if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { - return + return nil, err } - return + return s, nil } // JoinedMemberLists returns a map of room ID to a list of joined user IDs. -func (s *currentRoomStateStatements) selectJoinedUsers( +func (s *currentRoomStateStatements) SelectJoinedUsers( ctx context.Context, ) (map[string][]string, error) { rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) @@ -143,7 +146,7 @@ func (s *currentRoomStateStatements) selectJoinedUsers( } // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. -func (s *currentRoomStateStatements) selectRoomIDsWithMembership( +func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( ctx context.Context, txn *sql.Tx, userID string, @@ -168,7 +171,7 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership( } // CurrentState returns all the current state events for the given room. -func (s *currentRoomStateStatements) selectCurrentState( +func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, ) ([]gomatrixserverlib.HeaderedEvent, error) { @@ -189,7 +192,7 @@ func (s *currentRoomStateStatements) selectCurrentState( return rowsToEvents(rows) } -func (s *currentRoomStateStatements) deleteRoomStateByEventID( +func (s *currentRoomStateStatements) DeleteRoomStateByEventID( ctx context.Context, txn *sql.Tx, eventID string, ) error { stmt := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt) @@ -197,7 +200,7 @@ func (s *currentRoomStateStatements) deleteRoomStateByEventID( return err } -func (s *currentRoomStateStatements) upsertRoomState( +func (s *currentRoomStateStatements) UpsertRoomState( ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, ) error { @@ -231,7 +234,7 @@ func (s *currentRoomStateStatements) upsertRoomState( return err } -func (s *currentRoomStateStatements) selectEventsWithEventIDs( +func (s *currentRoomStateStatements) SelectEventsWithEventIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StreamEvent, error) { iEventIDs := make([]interface{}, len(eventIDs)) @@ -264,7 +267,7 @@ func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.HeaderedEvent, error) { return result, nil } -func (s *currentRoomStateStatements) selectStateEvent( +func (s *currentRoomStateStatements) SelectStateEvent( ctx context.Context, roomID, evType, stateKey string, ) (*gomatrixserverlib.HeaderedEvent, error) { stmt := s.selectStateEventStmt diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 72a62e990..182cbb2d7 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -54,9 +54,8 @@ type SyncServerDatasource struct { shared.Database db *sql.DB common.PartitionOffsetStatements - streamID streamIDStatements - roomstate currentRoomStateStatements - topology outputRoomEventsTopologyStatements + streamID streamIDStatements + topology outputRoomEventsTopologyStatements } // NewSyncServerDatasource creates a new sync server database @@ -99,7 +98,8 @@ func (d *SyncServerDatasource) prepare() (err error) { if err != nil { return err } - if err = d.roomstate.prepare(d.db, &d.streamID); err != nil { + roomState, err := NewSqliteCurrentRoomStateTable(d.db, &d.streamID) + if err != nil { return err } invites, err := NewSqliteInvitesTable(d.db, &d.streamID) @@ -119,16 +119,12 @@ func (d *SyncServerDatasource) prepare() (err error) { AccountData: accountData, OutputEvents: events, BackwardExtremities: bwExtrem, + CurrentRoomState: roomState, EDUCache: cache.New(), } return nil } -// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. -func (d *SyncServerDatasource) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { - return d.roomstate.selectJoinedUsers(ctx) -} - // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. @@ -210,7 +206,7 @@ func (d *SyncServerDatasource) updateRoomState( ) error { // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. for _, eventID := range removedEventIDs { - if err := d.roomstate.deleteRoomStateByEventID(ctx, txn, eventID); err != nil { + if err := d.Database.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil { return err } } @@ -228,7 +224,7 @@ func (d *SyncServerDatasource) updateRoomState( } membership = &value } - if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { + if err := d.Database.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { return err } } @@ -249,28 +245,6 @@ func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (tok types.Stre return } -// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key -// If no event could be found, returns nil -// If there was an issue during the retrieval, returns an error -func (d *SyncServerDatasource) GetStateEvent( - ctx context.Context, roomID, evType, stateKey string, -) (*gomatrixserverlib.HeaderedEvent, error) { - return d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey) -} - -// GetStateEventsForRoom fetches the state events for a given room. -// Returns an empty slice if no state events could be found for this room. -// Returns an error if there was an issue with the retrieval. -func (d *SyncServerDatasource) GetStateEventsForRoom( - ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, -) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart) - return err - }) - return -} - // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the // given extremities and limit. func (d *SyncServerDatasource) GetEventsInTopologicalRange( @@ -530,7 +504,7 @@ func (d *SyncServerDatasource) IncrementalSync( ctx, device, fromPos.PDUPosition(), toPos.PDUPosition(), numRecentEventsPerRoom, wantFullState, res, ) } else { - joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership( + joinedRoomIDs, err = d.Database.CurrentRoomState.SelectRoomIDsWithMembership( ctx, nil, device.UserID, gomatrixserverlib.Join, ) } @@ -585,7 +559,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( res = types.NewResponse(*toPos) // Extract room state and recent events for all rooms the user is joined to. - joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) + joinedRoomIDs, err = d.Database.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) if err != nil { return } @@ -595,7 +569,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( // Build up a /sync response. Add joined rooms. for _, roomID := range joinedRoomIDs { var stateEvents []gomatrixserverlib.HeaderedEvent - stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, &stateFilterPart) + stateEvents, err = d.Database.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, &stateFilterPart) if err != nil { return } @@ -849,7 +823,7 @@ func (d *SyncServerDatasource) fetchMissingStateEvents( // If they are missing from the events table then they should be state // events that we received from outside the main event stream. // These should be in the room state table. - stateEvents, err := d.roomstate.selectEventsWithEventIDs(ctx, txn, missing) + stateEvents, err := d.Database.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, missing) if err != nil { return nil, err @@ -921,7 +895,7 @@ func (d *SyncServerDatasource) getStateDeltas( } // Add in currently joined rooms - joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) + joinedRoomIDs, err := d.Database.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) if err != nil { return nil, nil, err } @@ -945,7 +919,7 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync( fromPos, toPos types.StreamPosition, userID string, stateFilterPart *gomatrixserverlib.StateFilter, ) ([]stateDelta, []string, error) { - joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) + joinedRoomIDs, err := d.Database.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) if err != nil { return nil, nil, err } @@ -1000,7 +974,7 @@ func (d *SyncServerDatasource) currentStateStreamEventsForRoom( ctx context.Context, txn *sql.Tx, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, ) ([]types.StreamEvent, error) { - allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart) + allState, err := d.Database.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilterPart) if err != nil { return nil, err }