From 0b70eead27b705660e6826dfc0792833059a57cf Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 2 Jun 2023 16:41:11 +0200 Subject: [PATCH] PR review comments, return existing key on conflict --- roomserver/storage/interface.go | 2 +- .../storage/postgres/user_room_keys_table.go | 19 ++++++++++------- roomserver/storage/shared/storage.go | 21 ++++++++++++------- roomserver/storage/shared/storage_test.go | 20 +++++++++++------- .../storage/sqlite3/user_room_keys_table.go | 19 ++++++++++------- roomserver/storage/tables/interface.go | 2 +- .../tables/user_room_keys_table_test.go | 19 ++++++++++------- 7 files changed, 63 insertions(+), 39 deletions(-) diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index ac77965fc..1d9697fc5 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -190,7 +190,7 @@ type Database interface { ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) - InsertUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) error + InsertUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) SelectUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error) } diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go index 15b1bff63..e60f6b6a5 100644 --- a/roomserver/storage/postgres/user_room_keys_table.go +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -18,8 +18,8 @@ import ( "context" "crypto/ed25519" "database/sql" + "errors" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -34,7 +34,11 @@ CREATE TABLE IF NOT EXISTS roomserver_user_room_keys ( ); ` -const insertUserRoomKeySQL = `INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key) VALUES ($1, $2, $3)` +const insertUserRoomKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key) VALUES ($1, $2, $3) + ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key + RETURNING (pseudo_id_key) +` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` type userRoomKeysStatements struct { @@ -61,11 +65,10 @@ func (s *userRoomKeysStatements) InsertUserRoomKey( userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey, -) error { +) (result ed25519.PrivateKey, err error) { stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomKeyStmt) - defer internal.CloseAndLogIfError(ctx, stmt, "failed to close statement") - _, err := stmt.ExecContext(ctx, userNID, roomNID, key) - return err + err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result) + return result, err } func (s *userRoomKeysStatements) SelectUserRoomKey( @@ -75,8 +78,10 @@ func (s *userRoomKeysStatements) SelectUserRoomKey( roomNID types.RoomNID, ) (ed25519.PrivateKey, error) { stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeyStmt) - defer internal.CloseAndLogIfError(ctx, stmt, "failed to close statement") var result ed25519.PrivateKey err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } return result, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index a1ea2ff59..8dd340e78 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -5,6 +5,7 @@ import ( "crypto/ed25519" "database/sql" "encoding/json" + "errors" "fmt" "sort" @@ -1592,21 +1593,25 @@ func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventS } // InsertUserRoomKey inserts a new user room key for the given user and room. -// Returns an error if a database error occurred, also if the primary constraint was violated. -func (d *Database) InsertUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.UserRoomKeyTable.InsertUserRoomKey(ctx, txn, userNID, roomNID, key) +// Returns the newly inserted private key or an existing private key. If there is +// an error talking to the database, returns that error. +func (d *Database) InsertUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var iErr error + result, iErr = d.UserRoomKeyTable.InsertUserRoomKey(ctx, txn, userNID, roomNID, key) + return iErr }) + return result, err } -// SelectUserRoomKey queries the user room key for a given user. -// Returns the key and an error. -// TODO: should we handle absent keys (sql.ErrNoRows) as non-fatal? +// SelectUserRoomKey queries the users room private key. +// If no key exists, returns no key and no error. Otherwise returns +// the key and a database error, if any. func (d *Database) SelectUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var sErr error key, sErr = d.UserRoomKeyTable.SelectUserRoomKey(ctx, txn, userNID, roomNID) - if sErr != nil { + if !errors.Is(sErr, sql.ErrNoRows) { return sErr } return nil diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index b05fafae3..63949a23f 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -120,19 +120,23 @@ func TestUserRoomKeys(t *testing.T) { _, key, err := ed25519.GenerateKey(nil) assert.NoError(t, err) - err = db.InsertUserRoomKey(ctx, userNID, roomNID, key) + gotKey, err := db.InsertUserRoomKey(ctx, userNID, roomNID, key) assert.NoError(t, err) + assert.Equal(t, gotKey, key) - // again, this should result in an error now, due to the primary key on userNID/roomNID - err = db.InsertUserRoomKey(context.Background(), userNID, roomNID, key) - assert.Error(t, err) + // again, this shouldn't result in an error, but return the existing key + _, key2, err := ed25519.GenerateKey(nil) + gotKey, err = db.InsertUserRoomKey(context.Background(), userNID, roomNID, key2) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) - gotKey, err := db.SelectUserRoomKey(context.Background(), userNID, roomNID) + gotKey, err = db.SelectUserRoomKey(context.Background(), userNID, roomNID) assert.NoError(t, err) assert.Equal(t, key, gotKey) - // Key doesn't exist - _, err = db.SelectUserRoomKey(context.Background(), userNID, 2) - assert.Error(t, err) + // Key doesn't exist, we shouldn't get anything back + gotKey, err = db.SelectUserRoomKey(context.Background(), userNID, 2) + assert.NoError(t, err) + assert.Nil(t, gotKey) }) } diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go index db83010bb..05ab6d01d 100644 --- a/roomserver/storage/sqlite3/user_room_keys_table.go +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -18,8 +18,8 @@ import ( "context" "crypto/ed25519" "database/sql" + "errors" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -34,7 +34,11 @@ CREATE TABLE IF NOT EXISTS roomserver_user_room_keys ( ); ` -const insertUserRoomKeySQL = `INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key) VALUES ($1, $2, $3)` +const insertUserRoomKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key) VALUES ($1, $2, $3) + ON CONFLICT DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key + RETURNING (pseudo_id_key) +` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` type userRoomKeysStatements struct { @@ -61,11 +65,10 @@ func (s *userRoomKeysStatements) InsertUserRoomKey( userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey, -) error { +) (result ed25519.PrivateKey, err error) { stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomKeyStmt) - defer internal.CloseAndLogIfError(ctx, stmt, "failed to close statement") - _, err := stmt.ExecContext(ctx, userNID, roomNID, key) - return err + err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result) + return result, err } func (s *userRoomKeysStatements) SelectUserRoomKey( @@ -75,8 +78,10 @@ func (s *userRoomKeysStatements) SelectUserRoomKey( roomNID types.RoomNID, ) (ed25519.PrivateKey, error) { stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeyStmt) - defer internal.CloseAndLogIfError(ctx, stmt, "failed to close statement") var result ed25519.PrivateKey err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } return result, err } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 2d875bffc..8b25617f7 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -186,7 +186,7 @@ type Purge interface { } type UserRoomKeys interface { - InsertUserRoomKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) error + InsertUserRoomKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error) SelectUserRoomKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) } diff --git a/roomserver/storage/tables/user_room_keys_table_test.go b/roomserver/storage/tables/user_room_keys_table_test.go index cc0550c5c..3e0c091ac 100644 --- a/roomserver/storage/tables/user_room_keys_table_test.go +++ b/roomserver/storage/tables/user_room_keys_table_test.go @@ -45,18 +45,23 @@ func TestUserRoomKeysTable(t *testing.T) { roomNID := types.RoomNID(1) _, key, err := ed25519.GenerateKey(nil) assert.NoError(t, err) - err = tab.InsertUserRoomKey(context.Background(), nil, userNID, roomNID, key) + gotKey, err := tab.InsertUserRoomKey(context.Background(), nil, userNID, roomNID, key) assert.NoError(t, err) - // again, this should result in an error now, due to the primary key on userNID/roomNID - err = tab.InsertUserRoomKey(context.Background(), nil, userNID, roomNID, key) - assert.Error(t, err) + assert.Equal(t, gotKey, key) - gotKey, err := tab.SelectUserRoomKey(context.Background(), nil, userNID, roomNID) + // again, this shouldn't result in an error, but return the existing key + _, key2, err := ed25519.GenerateKey(nil) + gotKey, err = tab.InsertUserRoomKey(context.Background(), nil, userNID, roomNID, key2) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) + + gotKey, err = tab.SelectUserRoomKey(context.Background(), nil, userNID, roomNID) assert.NoError(t, err) assert.Equal(t, key, gotKey) // Key doesn't exist - _, err = tab.SelectUserRoomKey(context.Background(), nil, userNID, 2) - assert.Error(t, err) + gotKey, err = tab.SelectUserRoomKey(context.Background(), nil, userNID, 2) + assert.NoError(t, err) + assert.Nil(t, gotKey) }) }