mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-16 10:33:11 -06:00
PR review comments, return existing key on conflict
This commit is contained in:
parent
acdb93b489
commit
0b70eead27
|
|
@ -190,7 +190,7 @@ type Database interface {
|
||||||
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver,
|
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver,
|
||||||
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)
|
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)
|
||||||
|
|
||||||
InsertUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) error
|
InsertUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error)
|
||||||
SelectUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error)
|
SelectUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
|
@ -34,7 +34,11 @@ CREATE TABLE IF NOT EXISTS roomserver_user_room_keys (
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertUserRoomKeySQL = `INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key) VALUES ($1, $2, $3)`
|
const insertUserRoomKeySQL = `
|
||||||
|
INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key) VALUES ($1, $2, $3)
|
||||||
|
ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key
|
||||||
|
RETURNING (pseudo_id_key)
|
||||||
|
`
|
||||||
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
|
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
|
||||||
|
|
||||||
type userRoomKeysStatements struct {
|
type userRoomKeysStatements struct {
|
||||||
|
|
@ -61,11 +65,10 @@ func (s *userRoomKeysStatements) InsertUserRoomKey(
|
||||||
userNID types.EventStateKeyNID,
|
userNID types.EventStateKeyNID,
|
||||||
roomNID types.RoomNID,
|
roomNID types.RoomNID,
|
||||||
key ed25519.PrivateKey,
|
key ed25519.PrivateKey,
|
||||||
) error {
|
) (result ed25519.PrivateKey, err error) {
|
||||||
stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomKeyStmt)
|
stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomKeyStmt)
|
||||||
defer internal.CloseAndLogIfError(ctx, stmt, "failed to close statement")
|
err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result)
|
||||||
_, err := stmt.ExecContext(ctx, userNID, roomNID, key)
|
return result, err
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *userRoomKeysStatements) SelectUserRoomKey(
|
func (s *userRoomKeysStatements) SelectUserRoomKey(
|
||||||
|
|
@ -75,8 +78,10 @@ func (s *userRoomKeysStatements) SelectUserRoomKey(
|
||||||
roomNID types.RoomNID,
|
roomNID types.RoomNID,
|
||||||
) (ed25519.PrivateKey, error) {
|
) (ed25519.PrivateKey, error) {
|
||||||
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeyStmt)
|
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeyStmt)
|
||||||
defer internal.CloseAndLogIfError(ctx, stmt, "failed to close statement")
|
|
||||||
var result ed25519.PrivateKey
|
var result ed25519.PrivateKey
|
||||||
err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result)
|
err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
|
|
@ -1592,21 +1593,25 @@ func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventS
|
||||||
}
|
}
|
||||||
|
|
||||||
// InsertUserRoomKey inserts a new user room key for the given user and room.
|
// InsertUserRoomKey inserts a new user room key for the given user and room.
|
||||||
// Returns an error if a database error occurred, also if the primary constraint was violated.
|
// Returns the newly inserted private key or an existing private key. If there is
|
||||||
func (d *Database) InsertUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) error {
|
// an error talking to the database, returns that error.
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
func (d *Database) InsertUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) {
|
||||||
return d.UserRoomKeyTable.InsertUserRoomKey(ctx, txn, userNID, roomNID, key)
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
var iErr error
|
||||||
|
result, iErr = d.UserRoomKeyTable.InsertUserRoomKey(ctx, txn, userNID, roomNID, key)
|
||||||
|
return iErr
|
||||||
})
|
})
|
||||||
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectUserRoomKey queries the user room key for a given user.
|
// SelectUserRoomKey queries the users room private key.
|
||||||
// Returns the key and an error.
|
// If no key exists, returns no key and no error. Otherwise returns
|
||||||
// TODO: should we handle absent keys (sql.ErrNoRows) as non-fatal?
|
// the key and a database error, if any.
|
||||||
func (d *Database) SelectUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error) {
|
func (d *Database) SelectUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
var sErr error
|
var sErr error
|
||||||
key, sErr = d.UserRoomKeyTable.SelectUserRoomKey(ctx, txn, userNID, roomNID)
|
key, sErr = d.UserRoomKeyTable.SelectUserRoomKey(ctx, txn, userNID, roomNID)
|
||||||
if sErr != nil {
|
if !errors.Is(sErr, sql.ErrNoRows) {
|
||||||
return sErr
|
return sErr
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
|
|
@ -120,19 +120,23 @@ func TestUserRoomKeys(t *testing.T) {
|
||||||
_, key, err := ed25519.GenerateKey(nil)
|
_, key, err := ed25519.GenerateKey(nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = db.InsertUserRoomKey(ctx, userNID, roomNID, key)
|
gotKey, err := db.InsertUserRoomKey(ctx, userNID, roomNID, key)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, gotKey, key)
|
||||||
|
|
||||||
// again, this should result in an error now, due to the primary key on userNID/roomNID
|
// again, this shouldn't result in an error, but return the existing key
|
||||||
err = db.InsertUserRoomKey(context.Background(), userNID, roomNID, key)
|
_, key2, err := ed25519.GenerateKey(nil)
|
||||||
assert.Error(t, err)
|
gotKey, err = db.InsertUserRoomKey(context.Background(), userNID, roomNID, key2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, gotKey, key)
|
||||||
|
|
||||||
gotKey, err := db.SelectUserRoomKey(context.Background(), userNID, roomNID)
|
gotKey, err = db.SelectUserRoomKey(context.Background(), userNID, roomNID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, key, gotKey)
|
assert.Equal(t, key, gotKey)
|
||||||
|
|
||||||
// Key doesn't exist
|
// Key doesn't exist, we shouldn't get anything back
|
||||||
_, err = db.SelectUserRoomKey(context.Background(), userNID, 2)
|
gotKey, err = db.SelectUserRoomKey(context.Background(), userNID, 2)
|
||||||
assert.Error(t, err)
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, gotKey)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
|
@ -34,7 +34,11 @@ CREATE TABLE IF NOT EXISTS roomserver_user_room_keys (
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertUserRoomKeySQL = `INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key) VALUES ($1, $2, $3)`
|
const insertUserRoomKeySQL = `
|
||||||
|
INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key) VALUES ($1, $2, $3)
|
||||||
|
ON CONFLICT DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key
|
||||||
|
RETURNING (pseudo_id_key)
|
||||||
|
`
|
||||||
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
|
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
|
||||||
|
|
||||||
type userRoomKeysStatements struct {
|
type userRoomKeysStatements struct {
|
||||||
|
|
@ -61,11 +65,10 @@ func (s *userRoomKeysStatements) InsertUserRoomKey(
|
||||||
userNID types.EventStateKeyNID,
|
userNID types.EventStateKeyNID,
|
||||||
roomNID types.RoomNID,
|
roomNID types.RoomNID,
|
||||||
key ed25519.PrivateKey,
|
key ed25519.PrivateKey,
|
||||||
) error {
|
) (result ed25519.PrivateKey, err error) {
|
||||||
stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomKeyStmt)
|
stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomKeyStmt)
|
||||||
defer internal.CloseAndLogIfError(ctx, stmt, "failed to close statement")
|
err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result)
|
||||||
_, err := stmt.ExecContext(ctx, userNID, roomNID, key)
|
return result, err
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *userRoomKeysStatements) SelectUserRoomKey(
|
func (s *userRoomKeysStatements) SelectUserRoomKey(
|
||||||
|
|
@ -75,8 +78,10 @@ func (s *userRoomKeysStatements) SelectUserRoomKey(
|
||||||
roomNID types.RoomNID,
|
roomNID types.RoomNID,
|
||||||
) (ed25519.PrivateKey, error) {
|
) (ed25519.PrivateKey, error) {
|
||||||
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeyStmt)
|
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeyStmt)
|
||||||
defer internal.CloseAndLogIfError(ctx, stmt, "failed to close statement")
|
|
||||||
var result ed25519.PrivateKey
|
var result ed25519.PrivateKey
|
||||||
err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result)
|
err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -186,7 +186,7 @@ type Purge interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserRoomKeys interface {
|
type UserRoomKeys interface {
|
||||||
InsertUserRoomKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) error
|
InsertUserRoomKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error)
|
||||||
SelectUserRoomKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error)
|
SelectUserRoomKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,18 +45,23 @@ func TestUserRoomKeysTable(t *testing.T) {
|
||||||
roomNID := types.RoomNID(1)
|
roomNID := types.RoomNID(1)
|
||||||
_, key, err := ed25519.GenerateKey(nil)
|
_, key, err := ed25519.GenerateKey(nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = tab.InsertUserRoomKey(context.Background(), nil, userNID, roomNID, key)
|
gotKey, err := tab.InsertUserRoomKey(context.Background(), nil, userNID, roomNID, key)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
// again, this should result in an error now, due to the primary key on userNID/roomNID
|
assert.Equal(t, gotKey, key)
|
||||||
err = tab.InsertUserRoomKey(context.Background(), nil, userNID, roomNID, key)
|
|
||||||
assert.Error(t, err)
|
|
||||||
|
|
||||||
gotKey, err := tab.SelectUserRoomKey(context.Background(), nil, userNID, roomNID)
|
// again, this shouldn't result in an error, but return the existing key
|
||||||
|
_, key2, err := ed25519.GenerateKey(nil)
|
||||||
|
gotKey, err = tab.InsertUserRoomKey(context.Background(), nil, userNID, roomNID, key2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, gotKey, key)
|
||||||
|
|
||||||
|
gotKey, err = tab.SelectUserRoomKey(context.Background(), nil, userNID, roomNID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, key, gotKey)
|
assert.Equal(t, key, gotKey)
|
||||||
|
|
||||||
// Key doesn't exist
|
// Key doesn't exist
|
||||||
_, err = tab.SelectUserRoomKey(context.Background(), nil, userNID, 2)
|
gotKey, err = tab.SelectUserRoomKey(context.Background(), nil, userNID, 2)
|
||||||
assert.Error(t, err)
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, gotKey)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue