Add roomID/roomNID to query for public keys

This commit is contained in:
Till Faelligen 2023-06-07 12:35:34 +02:00
parent f2f0bdd2f4
commit 5dac4ae017
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
8 changed files with 112 additions and 55 deletions

View file

@ -194,10 +194,10 @@ type Database interface {
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver,
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)
InsertUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error)
InsertUserRoomPrivatePublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error)
SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error)
InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error)
SelectUserIDsForPublicKeys(ctx context.Context, publicKeys [][]byte) (map[string]string, error)
SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error)
}
type RoomDatabase interface {

View file

@ -51,7 +51,7 @@ const insertUserRoomPublicKeySQL = `
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
const selectUserNIDsSQL = `SELECT user_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE pseudo_id_pub_key = ANY($1)`
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE pseudo_id_pub_key = ANY($1)`
type userRoomKeysStatements struct {
insertUserRoomPrivateKeyStmt *sql.Stmt
@ -102,27 +102,31 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey(
return result, err
}
func (s *userRoomKeysStatements) BulkSelectUserNIDs(
ctx context.Context,
txn *sql.Tx,
senderKeys [][]byte,
) (map[string]types.EventStateKeyNID, error) {
func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt)
rows, err := stmt.QueryContext(ctx, pq.Array(senderKeys))
roomNIDs := make([]types.RoomNID, 0, len(senderKeys))
var senders [][]byte
for roomNID := range senderKeys {
roomNIDs = append(roomNIDs, roomNID)
for _, key := range senderKeys[roomNID] {
senders = append(senders, key)
}
}
rows, err := stmt.QueryContext(ctx, pq.Array(senders))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
result := make(map[string]types.EventStateKeyNID, len(senderKeys))
result := make(map[string]types.UserRoomKeyPair, len(senders)+len(roomNIDs))
var publicKey []byte
var userNID types.EventStateKeyNID
userRoomKeyPair := types.UserRoomKeyPair{}
for rows.Next() {
if err = rows.Scan(&userNID, &publicKey); err != nil {
if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil {
return nil, err
}
result[string(publicKey)] = userNID
result[string(publicKey)] = userRoomKeyPair
}
return result, rows.Err()
}

View file

@ -13,6 +13,7 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/caching"
@ -1616,7 +1617,7 @@ func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventS
// InsertUserRoomPrivateKey inserts a new user room key for the given user and room.
// Returns the newly inserted private key or an existing private key. If there is
// an error talking to the database, returns that error.
func (d *Database) InsertUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) {
func (d *Database) InsertUserRoomPrivatePublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) {
uID := userID.String()
stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID})
if sErr != nil {
@ -1696,28 +1697,57 @@ func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.Use
return
}
// SelectUserIDsForPublicKeys returns a map from senderKey -> userID
func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys [][]byte) (result map[string]string, err error) {
result = make(map[string]string, len(publicKeys))
// SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID
func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) {
result = make(map[spec.RoomID]map[string]string, len(publicKeys))
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var sErr error
var userNIDKeyMap map[string]types.EventStateKeyNID
userNIDKeyMap, sErr = d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, txn, publicKeys)
// map all roomIDs to roomNIDs
query := make(map[types.RoomNID][]ed25519.PublicKey)
rooms := make(map[types.RoomNID]spec.RoomID)
for roomID, keys := range publicKeys {
roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID.String())
if !ok {
roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String())
if rErr != nil {
return rErr
}
if roomInfo == nil {
logrus.Warnf("missing room info for %s, there will be missing users in the response", roomID.String())
continue
}
roomNID = roomInfo.RoomNID
}
query[roomNID] = keys
rooms[roomNID] = roomID
}
// get the user room key pars
userRoomKeyPairMap, sErr := d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, txn, query)
if sErr != nil {
return sErr
}
nids := make([]types.EventStateKeyNID, 0, len(userNIDKeyMap))
for _, nid := range userNIDKeyMap {
nids = append(nids, nid)
nids := make([]types.EventStateKeyNID, 0, len(userRoomKeyPairMap))
for _, nid := range userRoomKeyPairMap {
nids = append(nids, nid.EventStateKeyNID)
}
// get the userIDs
nidMAP, seErr := d.EventStateKeys(ctx, nids)
if seErr != nil {
return seErr
}
for publicKey, userNID := range userNIDKeyMap {
userID := nidMAP[userNID]
result[publicKey] = userID
// build the result map (roomID -> map publicKey -> userID)
for publicKey, userRoomKeyPair := range userRoomKeyPairMap {
userID := nidMAP[userRoomKeyPair.EventStateKeyNID]
roomID := rooms[userRoomKeyPair.RoomNID]
resMap, exists := result[roomID]
if !exists {
resMap = map[string]string{}
}
resMap[publicKey] = userID
result[roomID] = resMap
}
return nil

View file

@ -148,14 +148,14 @@ func TestUserRoomKeys(t *testing.T) {
_, key, err := ed25519.GenerateKey(nil)
assert.NoError(t, err)
gotKey, err := db.InsertUserRoomPrivateKey(ctx, *userID, *roomID, key)
gotKey, err := db.InsertUserRoomPrivatePublicKey(ctx, *userID, *roomID, key)
assert.NoError(t, err)
assert.Equal(t, gotKey, key)
// again, this shouldn't result in an error, but return the existing key
_, key2, err := ed25519.GenerateKey(nil)
assert.NoError(t, err)
gotKey, err = db.InsertUserRoomPrivateKey(ctx, *userID, *roomID, key2)
gotKey, err = db.InsertUserRoomPrivatePublicKey(ctx, *userID, *roomID, key2)
assert.NoError(t, err)
assert.Equal(t, gotKey, key)
@ -169,9 +169,18 @@ func TestUserRoomKeys(t *testing.T) {
assert.NoError(t, err)
assert.Nil(t, gotKey)
userIDs, err := db.SelectUserIDsForPublicKeys(ctx, [][]byte{key.Public().(ed25519.PublicKey)})
queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{
*roomID: {key.Public().(ed25519.PublicKey)},
}
userIDs, err := db.SelectUserIDsForPublicKeys(ctx, queryUserIDs)
assert.NoError(t, err)
assert.NotNil(t, userIDs)
wantKeys := map[spec.RoomID]map[string]string{
*roomID: {
string(key.Public().(ed25519.PublicKey)): userID.String(),
},
}
assert.Equal(t, wantKeys, userIDs)
// insert key that came in over federation
var gotPublicKey, key4 ed255192.PublicKey
@ -186,7 +195,7 @@ func TestUserRoomKeys(t *testing.T) {
assert.NoError(t, err)
_, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *reallyDoesNotExist, key4)
assert.Error(t, err)
_, err = db.InsertUserRoomPrivateKey(context.Background(), *userID, *reallyDoesNotExist, key)
_, err = db.InsertUserRoomPrivatePublicKey(context.Background(), *userID, *reallyDoesNotExist, key)
assert.Error(t, err)
})
}

View file

@ -51,13 +51,13 @@ const insertUserRoomPublicKeySQL = `
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
const selectUserNIDsSQL = `SELECT user_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE pseudo_id_pub_key IN ($1)`
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
type userRoomKeysStatements struct {
insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt
selectUserNIDsStmt *sql.Stmt
//selectUserNIDsStmt *sql.Stmt //prepared at runtime
}
func CreateUserRoomKeysTable(db *sql.DB) error {
@ -71,7 +71,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL},
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
}.Prepare(db)
}
@ -102,25 +102,30 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey(
return result, err
}
func (s *userRoomKeysStatements) BulkSelectUserNIDs(
ctx context.Context,
txn *sql.Tx,
senderKeys [][]byte,
) (map[string]types.EventStateKeyNID, error) {
func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
roomNIDs := make([]any, 0, len(senderKeys))
var senders []any
for roomNID := range senderKeys {
roomNIDs = append(roomNIDs, roomNID)
for _, key := range senderKeys[roomNID] {
senders = append(senders, []byte(key))
}
}
selectSQL := strings.Replace(selectUserNIDsSQL, "($2)", sqlutil.QueryVariadicOffset(len(senders), len(senderKeys)), 1)
selectSQL = strings.Replace(selectSQL, "($1)", sqlutil.QueryVariadic(len(senderKeys)), 1) // replace $1 with the roomNIDs
selectSQL := strings.Replace(selectUserNIDsSQL, "($1)", sqlutil.QueryVariadic(len(senderKeys)), 1)
selectStmt, err := txn.Prepare(selectSQL)
if err != nil {
return nil, err
}
params := make([]interface{}, len(senderKeys))
for i := range senderKeys {
params[i] = senderKeys[i]
}
params := append(roomNIDs, senders...)
stmt := sqlutil.TxStmt(txn, selectStmt)
defer internal.CloseAndLogIfError(ctx, stmt, "failed to close transaction")
defer internal.CloseAndLogIfError(ctx, stmt, "failed to close statement")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
@ -128,14 +133,14 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(
}
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
result := make(map[string]types.EventStateKeyNID, len(senderKeys))
result := make(map[string]types.UserRoomKeyPair, len(params))
var publicKey []byte
var userNID types.EventStateKeyNID
userRoomKeyPair := types.UserRoomKeyPair{}
for rows.Next() {
if err = rows.Scan(&userNID, &publicKey); err != nil {
if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil {
return nil, err
}
result[string(publicKey)] = userNID
result[string(publicKey)] = userRoomKeyPair
}
return result, rows.Err()
}

View file

@ -189,7 +189,7 @@ type UserRoomKeys interface {
InsertUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error)
InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error)
SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error)
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys [][]byte) (map[string]types.EventStateKeyNID, error)
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
}
// StrippedEvent represents a stripped event for returning extracted content values.

View file

@ -78,14 +78,18 @@ func TestUserRoomKeysTable(t *testing.T) {
assert.Nil(t, gotKey)
// query user NIDs for senderKeys
var gotKeys map[string]types.EventStateKeyNID
gotKeys, err = tab.BulkSelectUserNIDs(context.Background(), txn, [][]byte{key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)})
var gotKeys map[string]types.UserRoomKeyPair
query := map[types.RoomNID][]ed25519.PublicKey{
roomNID: {key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)},
types.RoomNID(2): {key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}, // doesn't exist
}
gotKeys, err = tab.BulkSelectUserNIDs(context.Background(), txn, query)
assert.NoError(t, err)
assert.NotNil(t, gotKeys)
wantKeys := map[string]types.EventStateKeyNID{
string(key.Public().(ed25519.PublicKey)): userNID,
string(key3.Public().(ed25519.PublicKey)): userNID2,
wantKeys := map[string]types.UserRoomKeyPair{
string(key.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID},
string(key3.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID2},
}
assert.Equal(t, wantKeys, gotKeys)

View file

@ -44,6 +44,11 @@ type EventMetadata struct {
RoomNID RoomNID
}
type UserRoomKeyPair struct {
RoomNID RoomNID
EventStateKeyNID EventStateKeyNID
}
// StateSnapshotNID is a numeric ID for the state at an event.
type StateSnapshotNID int64