Add possibility to query userIDs for senderKeys

This commit is contained in:
Till Faelligen 2023-06-06 12:20:43 +02:00
parent d4b1074fdf
commit 35142ec42f
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
7 changed files with 168 additions and 23 deletions

View file

@ -192,6 +192,7 @@ type Database interface {
InsertUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) InsertUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error)
SelectUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error) SelectUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error)
SelectUserIDsForPublicKeys(ctx context.Context, publicKeys [][]byte) (map[string]string, error)
} }
type RoomDatabase interface { type RoomDatabase interface {

View file

@ -20,9 +20,12 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/sirupsen/logrus"
) )
const userRoomKeysSchema = ` const userRoomKeysSchema = `
@ -30,20 +33,24 @@ CREATE TABLE IF NOT EXISTS roomserver_user_room_keys (
user_nid INTEGER NOT NULL, user_nid INTEGER NOT NULL,
room_nid INTEGER NOT NULL, room_nid INTEGER NOT NULL,
pseudo_id_key BYTEA NOT NULL, pseudo_id_key BYTEA NOT NULL,
pseudo_id_pub_key BYTEA NOT NULL,
CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid) CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid)
); );
` `
const insertUserRoomKeySQL = ` const insertUserRoomKeySQL = `
INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key) VALUES ($1, $2, $3) INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key, pseudo_id_pub_key) VALUES ($1, $2, $3, $4)
ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key
RETURNING (pseudo_id_key) RETURNING (pseudo_id_key)
` `
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
const selectUserNIDsSQL = `SELECT user_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE pseudo_id_pub_key = ANY($1)`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
insertUserRoomKeyStmt *sql.Stmt insertUserRoomKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
selectUserIDsStmt *sql.Stmt
} }
func CreateUserRoomKeysTable(db *sql.DB) error { func CreateUserRoomKeysTable(db *sql.DB) error {
@ -56,6 +63,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertUserRoomKeyStmt, insertUserRoomKeySQL}, {&s.insertUserRoomKeyStmt, insertUserRoomKeySQL},
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserIDsStmt, selectUserNIDsSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -67,7 +75,7 @@ func (s *userRoomKeysStatements) InsertUserRoomKey(
key ed25519.PrivateKey, key ed25519.PrivateKey,
) (result ed25519.PrivateKey, err error) { ) (result ed25519.PrivateKey, err error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomKeyStmt) stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomKeyStmt)
err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result) err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result)
return result, err return result, err
} }
@ -85,3 +93,28 @@ func (s *userRoomKeysStatements) SelectUserRoomKey(
} }
return result, err return result, err
} }
func (s *userRoomKeysStatements) BulkSelectUserNIDs(
ctx context.Context,
txn *sql.Tx,
senderKeys [][]byte,
) (map[string]types.EventStateKeyNID, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserIDsStmt)
logrus.Infof("%#v", senderKeys)
rows, err := stmt.QueryContext(ctx, pq.Array(senderKeys))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
result := make(map[string]types.EventStateKeyNID, len(senderKeys))
var publicKey []byte
var userNID types.EventStateKeyNID
for rows.Next() {
if err = rows.Scan(&userNID, &publicKey); err != nil {
return nil, err
}
result[string(publicKey)] = userNID
}
return result, rows.Err()
}

View file

@ -1619,6 +1619,35 @@ func (d *Database) SelectUserRoomKey(ctx context.Context, userNID types.EventSta
return return
} }
// SelectUserIDsForPublicKeys returns a map from senderKey -> userID
func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys [][]byte) (result map[string]string, err error) {
result = make(map[string]string, len(publicKeys))
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var sErr error
var userNIDKeyMap map[string]types.EventStateKeyNID
userNIDKeyMap, sErr = d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, txn, publicKeys)
if sErr != nil {
return sErr
}
nids := make([]types.EventStateKeyNID, 0, len(userNIDKeyMap))
for _, nid := range userNIDKeyMap {
nids = append(nids, nid)
}
nidMAP, seErr := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, txn, nids)
if seErr != nil {
return err
}
for publicKey, userNID := range userNIDKeyMap {
userID := nidMAP[userNID]
result[publicKey] = userID
}
return nil
})
return result, err
}
// FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops // FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops
// it should live in this package! // it should live in this package!

View file

@ -115,11 +115,16 @@ func TestUserRoomKeys(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRoomserverDatabase(t, dbType) db, close := mustCreateRoomserverDatabase(t, dbType)
defer close() defer close()
userNID := types.EventStateKeyNID(1)
roomNID := types.RoomNID(1) roomNID := types.RoomNID(1)
_, key, err := ed25519.GenerateKey(nil) _, key, err := ed25519.GenerateKey(nil)
assert.NoError(t, err) assert.NoError(t, err)
// insert dummy event state keys
dummy := test.NewUser(t)
userNID, err := db.GetOrCreateEventStateKeyNID(ctx, &dummy.ID)
assert.NoError(t, err)
gotKey, err := db.InsertUserRoomKey(ctx, userNID, roomNID, key) gotKey, err := db.InsertUserRoomKey(ctx, userNID, roomNID, key)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, gotKey, key) assert.Equal(t, gotKey, key)
@ -139,5 +144,9 @@ func TestUserRoomKeys(t *testing.T) {
gotKey, err = db.SelectUserRoomKey(context.Background(), userNID, 2) gotKey, err = db.SelectUserRoomKey(context.Background(), userNID, 2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Nil(t, gotKey) assert.Nil(t, gotKey)
userIDs, err := db.SelectUserIDsForPublicKeys(ctx, [][]byte{key.Public().(ed25519.PublicKey)})
assert.NoError(t, err)
assert.NotNil(t, userIDs)
}) })
} }

View file

@ -19,7 +19,9 @@ import (
"crypto/ed25519" "crypto/ed25519"
"database/sql" "database/sql"
"errors" "errors"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -30,20 +32,24 @@ CREATE TABLE IF NOT EXISTS roomserver_user_room_keys (
user_nid INTEGER NOT NULL, user_nid INTEGER NOT NULL,
room_nid INTEGER NOT NULL, room_nid INTEGER NOT NULL,
pseudo_id_key TEXT NOT NULL, pseudo_id_key TEXT NOT NULL,
pseudo_id_pub_key TEXT NOT NULL,
CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid) CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid)
); );
` `
const insertUserRoomKeySQL = ` const insertUserRoomKeySQL = `
INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key) VALUES ($1, $2, $3) INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key, pseudo_id_pub_key) VALUES ($1, $2, $3, $4)
ON CONFLICT DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key ON CONFLICT DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key
RETURNING (pseudo_id_key) RETURNING (pseudo_id_key)
` `
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
const selectUserNIDsSQL = `SELECT user_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE pseudo_id_pub_key IN ($1)`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
insertUserRoomKeyStmt *sql.Stmt insertUserRoomKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
selectUserIDsStmt *sql.Stmt
} }
func CreateUserRoomKeysTable(db *sql.DB) error { func CreateUserRoomKeysTable(db *sql.DB) error {
@ -56,6 +62,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertUserRoomKeyStmt, insertUserRoomKeySQL}, {&s.insertUserRoomKeyStmt, insertUserRoomKeySQL},
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserIDsStmt, selectUserNIDsSQL}, //prepared at runtime
}.Prepare(db) }.Prepare(db)
} }
@ -67,7 +74,7 @@ func (s *userRoomKeysStatements) InsertUserRoomKey(
key ed25519.PrivateKey, key ed25519.PrivateKey,
) (result ed25519.PrivateKey, err error) { ) (result ed25519.PrivateKey, err error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomKeyStmt) stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomKeyStmt)
err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result) err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result)
return result, err return result, err
} }
@ -85,3 +92,41 @@ func (s *userRoomKeysStatements) SelectUserRoomKey(
} }
return result, err return result, err
} }
func (s *userRoomKeysStatements) BulkSelectUserNIDs(
ctx context.Context,
txn *sql.Tx,
senderKeys [][]byte,
) (map[string]types.EventStateKeyNID, error) {
selectSQL := strings.Replace(selectUserNIDsSQL, "($1)", sqlutil.QueryVariadic(len(senderKeys)), 1)
selectStmt, err := txn.Prepare(selectSQL)
if err != nil {
return nil, err
}
params := make([]interface{}, len(senderKeys))
for i := range senderKeys {
params[i] = senderKeys[i]
}
stmt := sqlutil.TxStmt(txn, selectStmt)
defer internal.CloseAndLogIfError(ctx, stmt, "failed to close transaction")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
result := make(map[string]types.EventStateKeyNID, len(senderKeys))
var publicKey []byte
var userNID types.EventStateKeyNID
for rows.Next() {
if err = rows.Scan(&userNID, &publicKey); err != nil {
return nil, err
}
result[string(publicKey)] = userNID
}
return result, rows.Err()
}

View file

@ -188,6 +188,7 @@ type Purge interface {
type UserRoomKeys interface { type UserRoomKeys interface {
InsertUserRoomKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error) InsertUserRoomKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error)
SelectUserRoomKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) SelectUserRoomKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error)
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys [][]byte) (map[string]types.EventStateKeyNID, error)
} }
// StrippedEvent represents a stripped event for returning extracted content values. // StrippedEvent represents a stripped event for returning extracted content values.

View file

@ -3,6 +3,7 @@ package tables_test
import ( import (
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"database/sql"
"testing" "testing"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -15,7 +16,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func mustCreateUserRoomKeysTable(t *testing.T, dbType test.DBType) (tab tables.UserRoomKeys, close func()) { func mustCreateUserRoomKeysTable(t *testing.T, dbType test.DBType) (tab tables.UserRoomKeys, db *sql.DB, close func()) {
t.Helper() t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{ db, err := sqlutil.Open(&config.DatabaseOptions{
@ -34,35 +35,61 @@ func mustCreateUserRoomKeysTable(t *testing.T, dbType test.DBType) (tab tables.U
} }
assert.NoError(t, err) assert.NoError(t, err)
return tab, close return tab, db, close
} }
func TestUserRoomKeysTable(t *testing.T) { func TestUserRoomKeysTable(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreateUserRoomKeysTable(t, dbType) tab, db, close := mustCreateUserRoomKeysTable(t, dbType)
defer close() defer close()
userNID := types.EventStateKeyNID(1) userNID := types.EventStateKeyNID(1)
roomNID := types.RoomNID(1) roomNID := types.RoomNID(1)
_, key, err := ed25519.GenerateKey(nil) _, key, err := ed25519.GenerateKey(nil)
assert.NoError(t, err) assert.NoError(t, err)
gotKey, err := tab.InsertUserRoomKey(context.Background(), nil, userNID, roomNID, key)
assert.NoError(t, err)
assert.Equal(t, gotKey, key)
// again, this shouldn't result in an error, but return the existing key err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
_, key2, err := ed25519.GenerateKey(nil) var gotKey, key2, key3 ed25519.PrivateKey
assert.NoError(t, err) gotKey, err = tab.InsertUserRoomKey(context.Background(), txn, userNID, roomNID, key)
gotKey, err = tab.InsertUserRoomKey(context.Background(), nil, userNID, roomNID, key2) assert.NoError(t, err)
assert.NoError(t, err) assert.Equal(t, gotKey, key)
assert.Equal(t, gotKey, key)
gotKey, err = tab.SelectUserRoomKey(context.Background(), nil, userNID, roomNID) // again, this shouldn't result in an error, but return the existing key
assert.NoError(t, err) _, key2, err = ed25519.GenerateKey(nil)
assert.Equal(t, key, gotKey) assert.NoError(t, err)
gotKey, err = tab.InsertUserRoomKey(context.Background(), txn, userNID, roomNID, key2)
assert.NoError(t, err)
assert.Equal(t, gotKey, key)
// Key doesn't exist // add another user
gotKey, err = tab.SelectUserRoomKey(context.Background(), nil, userNID, 2) _, key3, err = ed25519.GenerateKey(nil)
assert.NoError(t, err)
userNID2 := types.EventStateKeyNID(2)
_, err = tab.InsertUserRoomKey(context.Background(), txn, userNID2, roomNID, key3)
assert.NoError(t, err)
gotKey, err = tab.SelectUserRoomKey(context.Background(), txn, userNID, roomNID)
assert.NoError(t, err)
assert.Equal(t, key, gotKey)
// Key doesn't exist
gotKey, err = tab.SelectUserRoomKey(context.Background(), txn, userNID, 2)
assert.NoError(t, err)
assert.Nil(t, gotKey)
// query user NIDs for senderKeys
var gotKeys map[string]types.EventStateKeyNID
gotKeys, err = tab.BulkSelectUserNIDs(context.Background(), txn, [][]byte{key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)})
assert.NoError(t, err)
assert.NotNil(t, gotKeys)
wantKeys := map[string]types.EventStateKeyNID{
string(key.Public().(ed25519.PublicKey)): userNID,
string(key3.Public().(ed25519.PublicKey)): userNID2,
}
assert.Equal(t, wantKeys, gotKeys)
return nil
})
assert.NoError(t, err) assert.NoError(t, err)
assert.Nil(t, gotKey)
}) })
} }