diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go index 5c5b64396..630f5afba 100644 --- a/roomserver/storage/postgres/user_room_keys_table.go +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -45,7 +45,7 @@ const insertUserRoomPrivateKeySQL = ` const insertUserRoomPublicKeySQL = ` INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_pub_key) VALUES ($1, $2, $3) - ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_pub_key = roomserver_user_room_keys.pseudo_id_pub_key + ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_pub_key = $3 RETURNING (pseudo_id_pub_key) ` @@ -75,7 +75,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { }.Prepare(db) } -func (s *userRoomKeysStatements) InsertUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { +func (s *userRoomKeysStatements) InsertUserRoomPrivatePublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPrivateKeyStmt) err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result) return result, err diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index e8761a684..4466b5b28 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1614,7 +1614,7 @@ func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventS }) } -// InsertUserRoomPrivateKey inserts a new user room key for the given user and room. +// InsertUserRoomPrivatePublicKey inserts a new user room key for the given user and room. // Returns the newly inserted private key or an existing private key. If there is // an error talking to the database, returns that error. func (d *Database) InsertUserRoomPrivatePublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { @@ -1635,7 +1635,7 @@ func (d *Database) InsertUserRoomPrivatePublicKey(ctx context.Context, userID sp } var iErr error - result, iErr = d.UserRoomKeyTable.InsertUserRoomPrivateKey(ctx, txn, stateKeyNID, roomInfo.RoomNID, key) + result, iErr = d.UserRoomKeyTable.InsertUserRoomPrivatePublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID, key) return iErr }) return result, err diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go index d334c616d..8af57ea0e 100644 --- a/roomserver/storage/sqlite3/user_room_keys_table.go +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -45,7 +45,7 @@ const insertUserRoomKeySQL = ` const insertUserRoomPublicKeySQL = ` INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_pub_key) VALUES ($1, $2, $3) - ON CONFLICT DO UPDATE SET pseudo_id_pub_key = roomserver_user_room_keys.pseudo_id_pub_key + ON CONFLICT DO UPDATE SET pseudo_id_pub_key = $3 RETURNING (pseudo_id_pub_key) ` @@ -75,7 +75,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { }.Prepare(db) } -func (s *userRoomKeysStatements) InsertUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { +func (s *userRoomKeysStatements) InsertUserRoomPrivatePublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPrivateKeyStmt) err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result) return result, err diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index b2578021b..1f1e433af 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -186,7 +186,7 @@ type Purge interface { } type UserRoomKeys interface { - InsertUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error) + InsertUserRoomPrivatePublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error) InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error) SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) diff --git a/roomserver/storage/tables/user_room_keys_table_test.go b/roomserver/storage/tables/user_room_keys_table_test.go index 5b4e53a33..284309481 100644 --- a/roomserver/storage/tables/user_room_keys_table_test.go +++ b/roomserver/storage/tables/user_room_keys_table_test.go @@ -50,14 +50,14 @@ func TestUserRoomKeysTable(t *testing.T) { err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error { var gotKey, key2, key3 ed25519.PrivateKey - gotKey, err = tab.InsertUserRoomPrivateKey(context.Background(), txn, userNID, roomNID, key) + gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key) assert.NoError(t, err) assert.Equal(t, gotKey, key) // again, this shouldn't result in an error, but return the existing key _, key2, err = ed25519.GenerateKey(nil) assert.NoError(t, err) - gotKey, err = tab.InsertUserRoomPrivateKey(context.Background(), txn, userNID, roomNID, key2) + gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key2) assert.NoError(t, err) assert.Equal(t, gotKey, key) @@ -65,13 +65,19 @@ func TestUserRoomKeysTable(t *testing.T) { _, key3, err = ed25519.GenerateKey(nil) assert.NoError(t, err) userNID2 := types.EventStateKeyNID(2) - _, err = tab.InsertUserRoomPrivateKey(context.Background(), txn, userNID2, roomNID, key3) + _, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID2, roomNID, key3) assert.NoError(t, err) gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID) assert.NoError(t, err) assert.Equal(t, key, gotKey) + // try to update an existing key, this should only be done for users NOT on this homeserver + var gotPubKey ed25519.PublicKey + gotPubKey, err = tab.InsertUserRoomPublicKey(context.Background(), txn, userNID, roomNID, key2.Public().(ed25519.PublicKey)) + assert.NoError(t, err) + assert.Equal(t, key2.Public(), gotPubKey) + // Key doesn't exist gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2) assert.NoError(t, err) @@ -80,7 +86,7 @@ func TestUserRoomKeysTable(t *testing.T) { // query user NIDs for senderKeys var gotKeys map[string]types.UserRoomKeyPair query := map[types.RoomNID][]ed25519.PublicKey{ - roomNID: {key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}, + roomNID: {key2.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}, types.RoomNID(2): {key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}, // doesn't exist } gotKeys, err = tab.BulkSelectUserNIDs(context.Background(), txn, query) @@ -88,7 +94,7 @@ func TestUserRoomKeysTable(t *testing.T) { assert.NotNil(t, gotKeys) wantKeys := map[string]types.UserRoomKeyPair{ - string(key.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID}, + string(key2.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID}, string(key3.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID2}, } assert.Equal(t, wantKeys, gotKeys)