Add roomID/roomNID to query for public keys

This commit is contained in:
Till Faelligen 2023-06-07 12:35:34 +02:00
parent f2f0bdd2f4
commit 5dac4ae017
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
8 changed files with 112 additions and 55 deletions

View file

@ -194,10 +194,10 @@ type Database interface {
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver,
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)
InsertUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) InsertUserRoomPrivatePublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error)
SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error)
InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error)
SelectUserIDsForPublicKeys(ctx context.Context, publicKeys [][]byte) (map[string]string, error) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error)
} }
type RoomDatabase interface { type RoomDatabase interface {

View file

@ -51,7 +51,7 @@ const insertUserRoomPublicKeySQL = `
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
const selectUserNIDsSQL = `SELECT user_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE pseudo_id_pub_key = ANY($1)` const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE pseudo_id_pub_key = ANY($1)`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPrivateKeyStmt *sql.Stmt
@ -102,27 +102,31 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey(
return result, err return result, err
} }
func (s *userRoomKeysStatements) BulkSelectUserNIDs( func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
ctx context.Context,
txn *sql.Tx,
senderKeys [][]byte,
) (map[string]types.EventStateKeyNID, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt) stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt)
rows, err := stmt.QueryContext(ctx, pq.Array(senderKeys)) roomNIDs := make([]types.RoomNID, 0, len(senderKeys))
var senders [][]byte
for roomNID := range senderKeys {
roomNIDs = append(roomNIDs, roomNID)
for _, key := range senderKeys[roomNID] {
senders = append(senders, key)
}
}
rows, err := stmt.QueryContext(ctx, pq.Array(senders))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
result := make(map[string]types.EventStateKeyNID, len(senderKeys)) result := make(map[string]types.UserRoomKeyPair, len(senders)+len(roomNIDs))
var publicKey []byte var publicKey []byte
var userNID types.EventStateKeyNID userRoomKeyPair := types.UserRoomKeyPair{}
for rows.Next() { for rows.Next() {
if err = rows.Scan(&userNID, &publicKey); err != nil { if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil {
return nil, err return nil, err
} }
result[string(publicKey)] = userNID result[string(publicKey)] = userRoomKeyPair
} }
return result, rows.Err() return result, rows.Err()
} }

View file

@ -13,6 +13,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
@ -1616,7 +1617,7 @@ func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventS
// InsertUserRoomPrivateKey inserts a new user room key for the given user and room. // InsertUserRoomPrivateKey inserts a new user room key for the given user and room.
// Returns the newly inserted private key or an existing private key. If there is // Returns the newly inserted private key or an existing private key. If there is
// an error talking to the database, returns that error. // an error talking to the database, returns that error.
func (d *Database) InsertUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { func (d *Database) InsertUserRoomPrivatePublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) {
uID := userID.String() uID := userID.String()
stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID})
if sErr != nil { if sErr != nil {
@ -1696,28 +1697,57 @@ func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.Use
return return
} }
// SelectUserIDsForPublicKeys returns a map from senderKey -> userID // SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID
func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys [][]byte) (result map[string]string, err error) { func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) {
result = make(map[string]string, len(publicKeys)) result = make(map[spec.RoomID]map[string]string, len(publicKeys))
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var sErr error
var userNIDKeyMap map[string]types.EventStateKeyNID // map all roomIDs to roomNIDs
userNIDKeyMap, sErr = d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, txn, publicKeys) query := make(map[types.RoomNID][]ed25519.PublicKey)
rooms := make(map[types.RoomNID]spec.RoomID)
for roomID, keys := range publicKeys {
roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID.String())
if !ok {
roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String())
if rErr != nil {
return rErr
}
if roomInfo == nil {
logrus.Warnf("missing room info for %s, there will be missing users in the response", roomID.String())
continue
}
roomNID = roomInfo.RoomNID
}
query[roomNID] = keys
rooms[roomNID] = roomID
}
// get the user room key pars
userRoomKeyPairMap, sErr := d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, txn, query)
if sErr != nil { if sErr != nil {
return sErr return sErr
} }
nids := make([]types.EventStateKeyNID, 0, len(userNIDKeyMap)) nids := make([]types.EventStateKeyNID, 0, len(userRoomKeyPairMap))
for _, nid := range userNIDKeyMap { for _, nid := range userRoomKeyPairMap {
nids = append(nids, nid) nids = append(nids, nid.EventStateKeyNID)
} }
// get the userIDs
nidMAP, seErr := d.EventStateKeys(ctx, nids) nidMAP, seErr := d.EventStateKeys(ctx, nids)
if seErr != nil { if seErr != nil {
return seErr return seErr
} }
for publicKey, userNID := range userNIDKeyMap { // build the result map (roomID -> map publicKey -> userID)
userID := nidMAP[userNID] for publicKey, userRoomKeyPair := range userRoomKeyPairMap {
result[publicKey] = userID userID := nidMAP[userRoomKeyPair.EventStateKeyNID]
roomID := rooms[userRoomKeyPair.RoomNID]
resMap, exists := result[roomID]
if !exists {
resMap = map[string]string{}
}
resMap[publicKey] = userID
result[roomID] = resMap
} }
return nil return nil

View file

@ -148,14 +148,14 @@ func TestUserRoomKeys(t *testing.T) {
_, key, err := ed25519.GenerateKey(nil) _, key, err := ed25519.GenerateKey(nil)
assert.NoError(t, err) assert.NoError(t, err)
gotKey, err := db.InsertUserRoomPrivateKey(ctx, *userID, *roomID, key) gotKey, err := db.InsertUserRoomPrivatePublicKey(ctx, *userID, *roomID, key)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, gotKey, key) assert.Equal(t, gotKey, key)
// again, this shouldn't result in an error, but return the existing key // again, this shouldn't result in an error, but return the existing key
_, key2, err := ed25519.GenerateKey(nil) _, key2, err := ed25519.GenerateKey(nil)
assert.NoError(t, err) assert.NoError(t, err)
gotKey, err = db.InsertUserRoomPrivateKey(ctx, *userID, *roomID, key2) gotKey, err = db.InsertUserRoomPrivatePublicKey(ctx, *userID, *roomID, key2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, gotKey, key) assert.Equal(t, gotKey, key)
@ -169,9 +169,18 @@ func TestUserRoomKeys(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Nil(t, gotKey) assert.Nil(t, gotKey)
userIDs, err := db.SelectUserIDsForPublicKeys(ctx, [][]byte{key.Public().(ed25519.PublicKey)}) queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{
*roomID: {key.Public().(ed25519.PublicKey)},
}
userIDs, err := db.SelectUserIDsForPublicKeys(ctx, queryUserIDs)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, userIDs) wantKeys := map[spec.RoomID]map[string]string{
*roomID: {
string(key.Public().(ed25519.PublicKey)): userID.String(),
},
}
assert.Equal(t, wantKeys, userIDs)
// insert key that came in over federation // insert key that came in over federation
var gotPublicKey, key4 ed255192.PublicKey var gotPublicKey, key4 ed255192.PublicKey
@ -186,7 +195,7 @@ func TestUserRoomKeys(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
_, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *reallyDoesNotExist, key4) _, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *reallyDoesNotExist, key4)
assert.Error(t, err) assert.Error(t, err)
_, err = db.InsertUserRoomPrivateKey(context.Background(), *userID, *reallyDoesNotExist, key) _, err = db.InsertUserRoomPrivatePublicKey(context.Background(), *userID, *reallyDoesNotExist, key)
assert.Error(t, err) assert.Error(t, err)
}) })
} }

View file

@ -51,13 +51,13 @@ const insertUserRoomPublicKeySQL = `
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
const selectUserNIDsSQL = `SELECT user_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE pseudo_id_pub_key IN ($1)` const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
selectUserNIDsStmt *sql.Stmt //selectUserNIDsStmt *sql.Stmt //prepared at runtime
} }
func CreateUserRoomKeysTable(db *sql.DB) error { func CreateUserRoomKeysTable(db *sql.DB) error {
@ -71,7 +71,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL}, {&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL},
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime //{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
}.Prepare(db) }.Prepare(db)
} }
@ -102,25 +102,30 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey(
return result, err return result, err
} }
func (s *userRoomKeysStatements) BulkSelectUserNIDs( func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
ctx context.Context,
txn *sql.Tx, roomNIDs := make([]any, 0, len(senderKeys))
senderKeys [][]byte, var senders []any
) (map[string]types.EventStateKeyNID, error) { for roomNID := range senderKeys {
roomNIDs = append(roomNIDs, roomNID)
for _, key := range senderKeys[roomNID] {
senders = append(senders, []byte(key))
}
}
selectSQL := strings.Replace(selectUserNIDsSQL, "($2)", sqlutil.QueryVariadicOffset(len(senders), len(senderKeys)), 1)
selectSQL = strings.Replace(selectSQL, "($1)", sqlutil.QueryVariadic(len(senderKeys)), 1) // replace $1 with the roomNIDs
selectSQL := strings.Replace(selectUserNIDsSQL, "($1)", sqlutil.QueryVariadic(len(senderKeys)), 1)
selectStmt, err := txn.Prepare(selectSQL) selectStmt, err := txn.Prepare(selectSQL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
params := make([]interface{}, len(senderKeys)) params := append(roomNIDs, senders...)
for i := range senderKeys {
params[i] = senderKeys[i]
}
stmt := sqlutil.TxStmt(txn, selectStmt) stmt := sqlutil.TxStmt(txn, selectStmt)
defer internal.CloseAndLogIfError(ctx, stmt, "failed to close transaction") defer internal.CloseAndLogIfError(ctx, stmt, "failed to close statement")
rows, err := stmt.QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
@ -128,14 +133,14 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(
} }
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
result := make(map[string]types.EventStateKeyNID, len(senderKeys)) result := make(map[string]types.UserRoomKeyPair, len(params))
var publicKey []byte var publicKey []byte
var userNID types.EventStateKeyNID userRoomKeyPair := types.UserRoomKeyPair{}
for rows.Next() { for rows.Next() {
if err = rows.Scan(&userNID, &publicKey); err != nil { if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil {
return nil, err return nil, err
} }
result[string(publicKey)] = userNID result[string(publicKey)] = userRoomKeyPair
} }
return result, rows.Err() return result, rows.Err()
} }

View file

@ -189,7 +189,7 @@ type UserRoomKeys interface {
InsertUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error) InsertUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error)
InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error) InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error)
SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error)
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys [][]byte) (map[string]types.EventStateKeyNID, error) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
} }
// StrippedEvent represents a stripped event for returning extracted content values. // StrippedEvent represents a stripped event for returning extracted content values.

View file

@ -78,14 +78,18 @@ func TestUserRoomKeysTable(t *testing.T) {
assert.Nil(t, gotKey) assert.Nil(t, gotKey)
// query user NIDs for senderKeys // query user NIDs for senderKeys
var gotKeys map[string]types.EventStateKeyNID var gotKeys map[string]types.UserRoomKeyPair
gotKeys, err = tab.BulkSelectUserNIDs(context.Background(), txn, [][]byte{key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}) query := map[types.RoomNID][]ed25519.PublicKey{
roomNID: {key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)},
types.RoomNID(2): {key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}, // doesn't exist
}
gotKeys, err = tab.BulkSelectUserNIDs(context.Background(), txn, query)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, gotKeys) assert.NotNil(t, gotKeys)
wantKeys := map[string]types.EventStateKeyNID{ wantKeys := map[string]types.UserRoomKeyPair{
string(key.Public().(ed25519.PublicKey)): userNID, string(key.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID},
string(key3.Public().(ed25519.PublicKey)): userNID2, string(key3.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID2},
} }
assert.Equal(t, wantKeys, gotKeys) assert.Equal(t, wantKeys, gotKeys)

View file

@ -44,6 +44,11 @@ type EventMetadata struct {
RoomNID RoomNID RoomNID RoomNID
} }
type UserRoomKeyPair struct {
RoomNID RoomNID
EventStateKeyNID EventStateKeyNID
}
// StateSnapshotNID is a numeric ID for the state at an event. // StateSnapshotNID is a numeric ID for the state at an event.
type StateSnapshotNID int64 type StateSnapshotNID int64