From 250a0ee94659bd6e1100b9358c42ab267f842eaa Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 2 Feb 2022 14:12:55 +0000 Subject: [PATCH] Fill some gaps --- roomserver/internal/input/input_events.go | 2 +- .../internal/input/input_latest_events.go | 2 +- roomserver/state/state.go | 17 +++++++++++++--- .../postgres/event_state_keys_table.go | 10 ++++++---- .../storage/postgres/event_types_table.go | 5 +++-- roomserver/storage/shared/storage.go | 12 +++++------ .../storage/sqlite3/event_state_keys_table.go | 20 +++++++++++++------ .../storage/sqlite3/event_types_table.go | 5 +++-- roomserver/storage/tables/interface.go | 6 +++--- 9 files changed, 51 insertions(+), 28 deletions(-) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 83d972aee..3e9a4c414 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -483,7 +483,7 @@ func (r *Inputer) calculateAndSetState( isRejected bool, ) error { var err error - roomState := state.NewStateResolution(r.DB, roomInfo) + roomState := state.NewStateResolution(updater, roomInfo) if input.HasState { // Check here if we think we're in the room already. diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index ca154e7b8..e5041b0e3 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -191,7 +191,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - roomState := state.NewStateResolution(u.api.DB, u.roomInfo) + roomState := state.NewStateResolution(u.updater, u.roomInfo) // Work out if the state at the extremities has actually changed // or not. If they haven't then we won't bother doing all of the diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 15d592b46..e5f69521e 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -22,7 +22,6 @@ import ( "sort" "time" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" @@ -30,13 +29,25 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) +type StateResolutionStorage interface { + EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) + StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) + Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) +} + type StateResolution struct { - db storage.Database + db StateResolutionStorage roomInfo *types.RoomInfo events map[types.EventNID]*gomatrixserverlib.Event } -func NewStateResolution(db storage.Database, roomInfo *types.RoomInfo) StateResolution { +func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution { return StateResolution{ db: db, roomInfo: roomInfo, diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index 3a7cf03e3..762b3a1fc 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -111,9 +111,10 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( - ctx context.Context, eventStateKeys []string, + ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { - rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext( + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt) + rows, err := stmt.QueryContext( ctx, pq.StringArray(eventStateKeys), ) if err != nil { @@ -134,13 +135,14 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKey( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) for i := range eventStateKeyNIDs { nIDs[i] = int64(eventStateKeyNIDs[i]) } - rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs) + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyStmt) + rows, err := stmt.QueryContext(ctx, nIDs) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index e558072a5..1d5de5822 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -133,9 +133,10 @@ func (s *eventTypeStatements) SelectEventTypeNID( } func (s *eventTypeStatements) BulkSelectEventTypeNID( - ctx context.Context, eventTypes []string, + ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { - rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventTypeNIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventTypes)) if err != nil { return nil, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 787cde73b..ad7600730 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -62,7 +62,7 @@ func (d *Database) EventTypeNIDs( } } if len(remaining) > 0 { - nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, remaining) + nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, remaining) if err != nil { return nil, err } @@ -77,7 +77,7 @@ func (d *Database) EventTypeNIDs( func (d *Database) EventStateKeys( ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { - return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs) + return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs) } func (d *Database) EventStateKeyNIDs( @@ -93,7 +93,7 @@ func (d *Database) EventStateKeyNIDs( } } if len(remaining) > 0 { - nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, remaining) + nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, remaining) if err != nil { return nil, err } @@ -980,7 +980,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } // we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which // isn't a failure. - eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes) + eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err) } @@ -1000,7 +1000,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } - eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys) + eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err) } @@ -1076,7 +1076,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) stateKeyNIDs[i] = nid i++ } - nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs) + nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index 62fbce2d0..3b210fa22 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -112,15 +112,19 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( - ctx context.Context, eventStateKeys []string, + ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { iEventStateKeys := make([]interface{}, len(eventStateKeys)) for k, v := range eventStateKeys { iEventStateKeys[k] = v } selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1) - - rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + stmt := sqlutil.TxStmt(txn, selectPrep) + rows, err := stmt.QueryContext(ctx, iEventStateKeys...) if err != nil { return nil, err } @@ -138,15 +142,19 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKey( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs)) for k, v := range eventStateKeyNIDs { iEventStateKeyNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1) - - rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + stmt := sqlutil.TxStmt(txn, selectPrep) + rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 22df3fb22..f2c9c42fe 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -128,7 +128,7 @@ func (s *eventTypeStatements) SelectEventTypeNID( } func (s *eventTypeStatements) BulkSelectEventTypeNID( - ctx context.Context, eventTypes []string, + ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { /////////////// iEventTypes := make([]interface{}, len(eventTypes)) @@ -140,9 +140,10 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID( if err != nil { return nil, err } + stmt := sqlutil.TxStmt(txn, selectPrep) /////////////// - rows, err := selectPrep.QueryContext(ctx, iEventTypes...) + rows, err := stmt.QueryContext(ctx, iEventTypes...) if err != nil { return nil, err } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index b4275744e..fed39b944 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -24,14 +24,14 @@ type EventJSON interface { type EventTypes interface { InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) - BulkSelectEventTypeNID(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + BulkSelectEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypes []string) (map[string]types.EventTypeNID, error) } type EventStateKeys interface { InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) - BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) - BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) + BulkSelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + BulkSelectEventStateKey(ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) } type Events interface {