diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 699b42500..e915767cb 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -190,8 +190,9 @@ type Database interface { ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) - InsertUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) - SelectUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error) + InsertUserRoomPrivateKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) + SelectUserRoomPrivateKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error) + InsertUserRoomPublicKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys [][]byte) (map[string]string, error) } diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go index dcd2a08ae..f969cd246 100644 --- a/roomserver/storage/postgres/user_room_keys_table.go +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -31,25 +31,33 @@ const userRoomKeysSchema = ` CREATE TABLE IF NOT EXISTS roomserver_user_room_keys ( user_nid INTEGER NOT NULL, room_nid INTEGER NOT NULL, - pseudo_id_key BYTEA NOT NULL, + pseudo_id_key BYTEA NULL, -- may be null for users not local to the server pseudo_id_pub_key BYTEA NOT NULL, CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid) ); ` -const insertUserRoomKeySQL = ` +const insertUserRoomPrivateKeySQL = ` INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key, pseudo_id_pub_key) VALUES ($1, $2, $3, $4) ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key RETURNING (pseudo_id_key) ` + +const insertUserRoomPublicKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_pub_key) VALUES ($1, $2, $3) + ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_pub_key = roomserver_user_room_keys.pseudo_id_pub_key + RETURNING (pseudo_id_pub_key) +` + const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` const selectUserNIDsSQL = `SELECT user_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE pseudo_id_pub_key = ANY($1)` type userRoomKeysStatements struct { - insertUserRoomKeyStmt *sql.Stmt - selectUserRoomKeyStmt *sql.Stmt - selectUserIDsStmt *sql.Stmt + insertUserRoomPrivateKeyStmt *sql.Stmt + insertUserRoomPublicKeyStmt *sql.Stmt + selectUserRoomKeyStmt *sql.Stmt + selectUserNIDsStmt *sql.Stmt } func CreateUserRoomKeysTable(db *sql.DB) error { @@ -60,25 +68,26 @@ func CreateUserRoomKeysTable(db *sql.DB) error { func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { s := &userRoomKeysStatements{} return s, sqlutil.StatementList{ - {&s.insertUserRoomKeyStmt, insertUserRoomKeySQL}, + {&s.insertUserRoomPrivateKeyStmt, insertUserRoomPrivateKeySQL}, + {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, - {&s.selectUserIDsStmt, selectUserNIDsSQL}, + {&s.selectUserNIDsStmt, selectUserNIDsSQL}, }.Prepare(db) } -func (s *userRoomKeysStatements) InsertUserRoomKey( - ctx context.Context, - txn *sql.Tx, - userNID types.EventStateKeyNID, - roomNID types.RoomNID, - key ed25519.PrivateKey, -) (result ed25519.PrivateKey, err error) { - stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomKeyStmt) +func (s *userRoomKeysStatements) InsertUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPrivateKeyStmt) err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result) return result, err } -func (s *userRoomKeysStatements) SelectUserRoomKey( +func (s *userRoomKeysStatements) InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPublicKeyStmt) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result) + return result, err +} + +func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, @@ -98,7 +107,7 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs( txn *sql.Tx, senderKeys [][]byte, ) (map[string]types.EventStateKeyNID, error) { - stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserIDsStmt) + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt) rows, err := stmt.QueryContext(ctx, pq.Array(senderKeys)) if err != nil { diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 71bca2de5..92b69dfcb 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1592,25 +1592,37 @@ func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventS }) } -// InsertUserRoomKey inserts a new user room key for the given user and room. +// InsertUserRoomPrivateKey inserts a new user room key for the given user and room. // Returns the newly inserted private key or an existing private key. If there is // an error talking to the database, returns that error. -func (d *Database) InsertUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { +func (d *Database) InsertUserRoomPrivateKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var iErr error - result, iErr = d.UserRoomKeyTable.InsertUserRoomKey(ctx, txn, userNID, roomNID, key) + result, iErr = d.UserRoomKeyTable.InsertUserRoomPrivateKey(ctx, txn, userNID, roomNID, key) return iErr }) return result, err } -// SelectUserRoomKey queries the users room private key. +// InsertUserRoomPrivateKey inserts a new user room key for the given user and room. +// Returns the newly inserted private key or an existing private key. If there is +// an error talking to the database, returns that error. +func (d *Database) InsertUserRoomPublicKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var iErr error + result, iErr = d.UserRoomKeyTable.InsertUserRoomPublicKey(ctx, txn, userNID, roomNID, key) + return iErr + }) + return result, err +} + +// SelectUserRoomPrivateKey queries the users room private key. // If no key exists, returns no key and no error. Otherwise returns // the key and a database error, if any. -func (d *Database) SelectUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error) { +func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var sErr error - key, sErr = d.UserRoomKeyTable.SelectUserRoomKey(ctx, txn, userNID, roomNID) + key, sErr = d.UserRoomKeyTable.SelectUserRoomPrivateKey(ctx, txn, userNID, roomNID) if !errors.Is(sErr, sql.ErrNoRows) { return sErr } diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index c4dcabc98..8c60d6821 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -9,6 +9,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/roomserver/types" "github.com/stretchr/testify/assert" + ed255192 "golang.org/x/crypto/ed25519" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres" @@ -125,28 +126,37 @@ func TestUserRoomKeys(t *testing.T) { userNID, err := db.GetOrCreateEventStateKeyNID(ctx, &dummy.ID) assert.NoError(t, err) - gotKey, err := db.InsertUserRoomKey(ctx, userNID, roomNID, key) + gotKey, err := db.InsertUserRoomPrivateKey(ctx, userNID, roomNID, key) assert.NoError(t, err) assert.Equal(t, gotKey, key) // again, this shouldn't result in an error, but return the existing key _, key2, err := ed25519.GenerateKey(nil) assert.NoError(t, err) - gotKey, err = db.InsertUserRoomKey(context.Background(), userNID, roomNID, key2) + gotKey, err = db.InsertUserRoomPrivateKey(context.Background(), userNID, roomNID, key2) assert.NoError(t, err) assert.Equal(t, gotKey, key) - gotKey, err = db.SelectUserRoomKey(context.Background(), userNID, roomNID) + gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), userNID, roomNID) assert.NoError(t, err) assert.Equal(t, key, gotKey) // Key doesn't exist, we shouldn't get anything back - gotKey, err = db.SelectUserRoomKey(context.Background(), userNID, 2) + gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), userNID, 2) assert.NoError(t, err) assert.Nil(t, gotKey) userIDs, err := db.SelectUserIDsForPublicKeys(ctx, [][]byte{key.Public().(ed25519.PublicKey)}) assert.NoError(t, err) assert.NotNil(t, userIDs) + + // insert key that came in over federation + var gotPublicKey, key4 ed255192.PublicKey + key4, _, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + gotPublicKey, err = db.InsertUserRoomPublicKey(context.Background(), userNID, 2, key4) + assert.NoError(t, err) + assert.Equal(t, key4, gotPublicKey) + }) } diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go index 300ae8411..155391a59 100644 --- a/roomserver/storage/sqlite3/user_room_keys_table.go +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -31,7 +31,7 @@ const userRoomKeysSchema = ` CREATE TABLE IF NOT EXISTS roomserver_user_room_keys ( user_nid INTEGER NOT NULL, room_nid INTEGER NOT NULL, - pseudo_id_key TEXT NOT NULL, + pseudo_id_key TEXT NULL, -- may be null for users not local to the server pseudo_id_pub_key TEXT NOT NULL, CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid) ); @@ -42,14 +42,22 @@ const insertUserRoomKeySQL = ` ON CONFLICT DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key RETURNING (pseudo_id_key) ` + +const insertUserRoomPublicKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_pub_key) VALUES ($1, $2, $3) + ON CONFLICT DO UPDATE SET pseudo_id_pub_key = roomserver_user_room_keys.pseudo_id_pub_key + RETURNING (pseudo_id_pub_key) +` + const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` const selectUserNIDsSQL = `SELECT user_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE pseudo_id_pub_key IN ($1)` type userRoomKeysStatements struct { - insertUserRoomKeyStmt *sql.Stmt - selectUserRoomKeyStmt *sql.Stmt - selectUserIDsStmt *sql.Stmt + insertUserRoomPrivateKeyStmt *sql.Stmt + insertUserRoomPublicKeyStmt *sql.Stmt + selectUserRoomKeyStmt *sql.Stmt + selectUserNIDsStmt *sql.Stmt } func CreateUserRoomKeysTable(db *sql.DB) error { @@ -60,25 +68,26 @@ func CreateUserRoomKeysTable(db *sql.DB) error { func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { s := &userRoomKeysStatements{} return s, sqlutil.StatementList{ - {&s.insertUserRoomKeyStmt, insertUserRoomKeySQL}, + {&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL}, + {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, - {&s.selectUserIDsStmt, selectUserNIDsSQL}, //prepared at runtime + {&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime }.Prepare(db) } -func (s *userRoomKeysStatements) InsertUserRoomKey( - ctx context.Context, - txn *sql.Tx, - userNID types.EventStateKeyNID, - roomNID types.RoomNID, - key ed25519.PrivateKey, -) (result ed25519.PrivateKey, err error) { - stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomKeyStmt) +func (s *userRoomKeysStatements) InsertUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPrivateKeyStmt) err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result) return result, err } -func (s *userRoomKeysStatements) SelectUserRoomKey( +func (s *userRoomKeysStatements) InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPublicKeyStmt) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result) + return result, err +} + +func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 98e27b858..3064628ca 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -186,8 +186,9 @@ type Purge interface { } type UserRoomKeys interface { - InsertUserRoomKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error) - SelectUserRoomKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) + InsertUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error) + InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error) + SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys [][]byte) (map[string]types.EventStateKeyNID, error) } diff --git a/roomserver/storage/tables/user_room_keys_table_test.go b/roomserver/storage/tables/user_room_keys_table_test.go index f515460e3..253f2b427 100644 --- a/roomserver/storage/tables/user_room_keys_table_test.go +++ b/roomserver/storage/tables/user_room_keys_table_test.go @@ -14,6 +14,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/stretchr/testify/assert" + ed255192 "golang.org/x/crypto/ed25519" ) func mustCreateUserRoomKeysTable(t *testing.T, dbType test.DBType) (tab tables.UserRoomKeys, db *sql.DB, close func()) { @@ -49,14 +50,14 @@ func TestUserRoomKeysTable(t *testing.T) { err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error { var gotKey, key2, key3 ed25519.PrivateKey - gotKey, err = tab.InsertUserRoomKey(context.Background(), txn, userNID, roomNID, key) + gotKey, err = tab.InsertUserRoomPrivateKey(context.Background(), txn, userNID, roomNID, key) assert.NoError(t, err) assert.Equal(t, gotKey, key) // again, this shouldn't result in an error, but return the existing key _, key2, err = ed25519.GenerateKey(nil) assert.NoError(t, err) - gotKey, err = tab.InsertUserRoomKey(context.Background(), txn, userNID, roomNID, key2) + gotKey, err = tab.InsertUserRoomPrivateKey(context.Background(), txn, userNID, roomNID, key2) assert.NoError(t, err) assert.Equal(t, gotKey, key) @@ -64,15 +65,15 @@ func TestUserRoomKeysTable(t *testing.T) { _, key3, err = ed25519.GenerateKey(nil) assert.NoError(t, err) userNID2 := types.EventStateKeyNID(2) - _, err = tab.InsertUserRoomKey(context.Background(), txn, userNID2, roomNID, key3) + _, err = tab.InsertUserRoomPrivateKey(context.Background(), txn, userNID2, roomNID, key3) assert.NoError(t, err) - gotKey, err = tab.SelectUserRoomKey(context.Background(), txn, userNID, roomNID) + gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID) assert.NoError(t, err) assert.Equal(t, key, gotKey) // Key doesn't exist - gotKey, err = tab.SelectUserRoomKey(context.Background(), txn, userNID, 2) + gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2) assert.NoError(t, err) assert.Nil(t, gotKey) @@ -87,6 +88,15 @@ func TestUserRoomKeysTable(t *testing.T) { string(key3.Public().(ed25519.PublicKey)): userNID2, } assert.Equal(t, wantKeys, gotKeys) + + // insert key that came in over federation + var gotPublicKey, key4 ed255192.PublicKey + key4, _, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + gotPublicKey, err = tab.InsertUserRoomPublicKey(context.Background(), txn, userNID, 2, key4) + assert.NoError(t, err) + assert.Equal(t, key4, gotPublicKey) + return nil }) assert.NoError(t, err)