diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 3c918a559..756dc32ad 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -77,9 +78,9 @@ type Database interface { // MarkDeviceListStale sets the stale bit for this user to isStale. MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error - CrossSigningKeysForUser(ctx context.Context, userID string) (api.CrossSigningKeyMap, error) - CrossSigningSigsForTarget(ctx context.Context, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (api.CrossSigningSigMap, error) + CrossSigningKeysForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) + CrossSigningSigsForTarget(ctx context.Context, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) - StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap api.CrossSigningKeyMap) error + StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error } diff --git a/keyserver/storage/postgres/cross_signing_keys_table.go b/keyserver/storage/postgres/cross_signing_keys_table.go index 63ad93035..1022157e8 100644 --- a/keyserver/storage/postgres/cross_signing_keys_table.go +++ b/keyserver/storage/postgres/cross_signing_keys_table.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -67,13 +66,13 @@ func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, erro func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( ctx context.Context, txn *sql.Tx, userID string, -) (r api.CrossSigningKeyMap, err error) { +) (r types.CrossSigningKeyMap, err error) { rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed") - r = api.CrossSigningKeyMap{} + r = types.CrossSigningKeyMap{} for rows.Next() { var keyTypeInt int16 var keyData gomatrixserverlib.Base64Bytes diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/keyserver/storage/postgres/cross_signing_sigs_table.go index cbdd5a5d1..677e7a48c 100644 --- a/keyserver/storage/postgres/cross_signing_sigs_table.go +++ b/keyserver/storage/postgres/cross_signing_sigs_table.go @@ -21,8 +21,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -68,13 +68,13 @@ func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, erro func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID, -) (r api.CrossSigningSigMap, err error) { +) (r types.CrossSigningSigMap, err error) { rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed") - r = api.CrossSigningSigMap{} + r = types.CrossSigningSigMap{} for rows.Next() { var userID string var keyID gomatrixserverlib.KeyID diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 0d01689c9..767242950 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -158,17 +159,17 @@ func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isSta } // CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. -func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (api.CrossSigningKeyMap, error) { +func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) } // CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. -func (d *Database) CrossSigningSigsForTarget(ctx context.Context, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (api.CrossSigningSigMap, error) { +func (d *Database) CrossSigningSigsForTarget(ctx context.Context, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, targetUserID, targetKeyID) } // StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. -func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap api.CrossSigningKeyMap) error { +func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { for keyType, keyData := range keyMap { if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { diff --git a/keyserver/storage/sqlite3/cross_signing_keys_table.go b/keyserver/storage/sqlite3/cross_signing_keys_table.go index 5aac9cac9..e103d9883 100644 --- a/keyserver/storage/sqlite3/cross_signing_keys_table.go +++ b/keyserver/storage/sqlite3/cross_signing_keys_table.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -66,13 +65,13 @@ func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( ctx context.Context, txn *sql.Tx, userID string, -) (r api.CrossSigningKeyMap, err error) { +) (r types.CrossSigningKeyMap, err error) { rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed") - r = api.CrossSigningKeyMap{} + r = types.CrossSigningKeyMap{} for rows.Next() { var keyTypeInt int16 var keyData gomatrixserverlib.Base64Bytes diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/keyserver/storage/sqlite3/cross_signing_sigs_table.go index 120921c42..aa7025831 100644 --- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go +++ b/keyserver/storage/sqlite3/cross_signing_sigs_table.go @@ -21,8 +21,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -67,13 +67,13 @@ func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID, -) (r api.CrossSigningSigMap, err error) { +) (r types.CrossSigningSigMap, err error) { rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed") - r = api.CrossSigningSigMap{} + r = types.CrossSigningSigMap{} for rows.Next() { var userID string var keyID gomatrixserverlib.KeyID diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index e5e253877..0649b6803 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -20,6 +20,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -54,11 +55,11 @@ type StaleDeviceLists interface { } type CrossSigningKeys interface { - SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r api.CrossSigningKeyMap, err error) + SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error } type CrossSigningSigs interface { - SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r api.CrossSigningSigMap, err error) + SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error } diff --git a/keyserver/types/storage.go b/keyserver/types/storage.go index e5f914b62..3480ec65f 100644 --- a/keyserver/types/storage.go +++ b/keyserver/types/storage.go @@ -31,3 +31,9 @@ var KeyTypeIntToPurpose = map[int16]gomatrixserverlib.CrossSigningKeyPurpose{ 2: gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, 3: gomatrixserverlib.CrossSigningKeyPurposeUserSigning, } + +// Map of purpose -> public key +type CrossSigningKeyMap map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.Base64Bytes + +// Map of user ID -> key ID -> signature +type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes diff --git a/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go b/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go index 42edbbc6f..8d0331748 100644 --- a/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go +++ b/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go @@ -93,6 +93,20 @@ func UpStateBlocksRefactor(tx *sql.Tx) error { } var newblocks types.StateBlockNIDs + if len(blocks) == 0 { + // some m.room.create events have a state snapshot but no state blocks at all which makes + // sense as there is no state before creation. The correct form should be to give the event + // in question a state snapshot NID of 0 to indicate 'no snapshot'. + // If we don't do this, we'll fail the assertions later on which try to ensure we didn't forget + // any snapshots. + _, err = tx.Exec( + `UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE event_type_nid = $1 AND event_state_key_nid = $2 AND state_snapshot_nid = $3`, + types.MRoomCreateNID, types.EmptyStateKeyNID, snapshot, + ) + if err != nil { + return fmt.Errorf("resetting create events snapshots to 0 errored: %s", err) + } + } for _, block := range blocks { if err = func() error { blockrows, berr := tx.Query(`SELECT event_nid FROM _roomserver_state_block WHERE state_block_nid = $1`, block) diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index a28d95fa5..b7fe7ee4f 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -571,6 +571,9 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs( } func eventNIDsAsArray(eventNIDs []types.EventNID) string { + if eventNIDs == nil { + eventNIDs = []types.EventNID{} // don't store 'null' in the DB + } b, _ := json.Marshal(eventNIDs) return string(b) } diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index a472437a3..58b0b5dc2 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -86,7 +86,7 @@ func (s *stateBlockStatements) BulkInsertStateData( entries types.StateEntries, ) (id types.StateBlockNID, err error) { entries = entries[:util.SortAndUnique(entries)] - var nids types.EventNIDs + nids := types.EventNIDs{} // zero slice to not store 'null' in the DB for _, e := range entries { nids = append(nids, e.EventNID) } diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index ad623d65a..040d99ae6 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -87,6 +87,9 @@ func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { func (s *stateSnapshotStatements) InsertState( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs, ) (stateNID types.StateSnapshotNID, err error) { + if stateBlockNIDs == nil { + stateBlockNIDs = []types.StateBlockNID{} // zero slice to not store 'null' in the DB + } stateBlockNIDs = stateBlockNIDs[:util.SortAndUnique(stateBlockNIDs)] stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs) if err != nil {