Accept spec.RoomID and spec.UserID instead of NIDs

This commit is contained in:
Till Faelligen 2023-06-06 16:10:30 +02:00
parent c24a052f01
commit 9781c54527
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
3 changed files with 100 additions and 25 deletions

View file

@ -190,9 +190,9 @@ type Database interface {
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver,
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)
InsertUserRoomPrivateKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error)
SelectUserRoomPrivateKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error)
InsertUserRoomPublicKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error)
InsertUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error)
SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error)
InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error)
SelectUserIDsForPublicKeys(ctx context.Context, publicKeys [][]byte) (map[string]string, error)
}

View file

@ -9,6 +9,7 @@ import (
"fmt"
"sort"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
@ -1595,10 +1596,25 @@ func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventS
// InsertUserRoomPrivateKey inserts a new user room key for the given user and room.
// Returns the newly inserted private key or an existing private key. If there is
// an error talking to the database, returns that error.
func (d *Database) InsertUserRoomPrivateKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) {
func (d *Database) InsertUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) {
uID := userID.String()
stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID})
if sErr != nil {
return nil, sErr
}
stateKeyNID := stateKeyNIDMap[uID]
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String())
if rErr != nil {
return rErr
}
if roomInfo == nil {
return eventutil.ErrRoomNoExists{}
}
var iErr error
result, iErr = d.UserRoomKeyTable.InsertUserRoomPrivateKey(ctx, txn, userNID, roomNID, key)
result, iErr = d.UserRoomKeyTable.InsertUserRoomPrivateKey(ctx, txn, stateKeyNID, roomInfo.RoomNID, key)
return iErr
})
return result, err
@ -1607,10 +1623,25 @@ func (d *Database) InsertUserRoomPrivateKey(ctx context.Context, userNID types.E
// InsertUserRoomPublicKey inserts a new user room key for the given user and room.
// Returns the newly inserted public key or an existing public key. If there is
// an error talking to the database, returns that error.
func (d *Database) InsertUserRoomPublicKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) {
func (d *Database) InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) {
uID := userID.String()
stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID})
if sErr != nil {
return nil, sErr
}
stateKeyNID := stateKeyNIDMap[uID]
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String())
if rErr != nil {
return rErr
}
if roomInfo == nil {
return eventutil.ErrRoomNoExists{}
}
var iErr error
result, iErr = d.UserRoomKeyTable.InsertUserRoomPublicKey(ctx, txn, userNID, roomNID, key)
result, iErr = d.UserRoomKeyTable.InsertUserRoomPublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID, key)
return iErr
})
return result, err
@ -1619,10 +1650,24 @@ func (d *Database) InsertUserRoomPublicKey(ctx context.Context, userNID types.Ev
// SelectUserRoomPrivateKey queries the users room private key.
// If no key exists, returns no key and no error. Otherwise returns
// the key and a database error, if any.
func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error) {
func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) {
uID := userID.String()
stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID})
if sErr != nil {
return nil, sErr
}
stateKeyNID := stateKeyNIDMap[uID]
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var sErr error
key, sErr = d.UserRoomKeyTable.SelectUserRoomPrivateKey(ctx, txn, userNID, roomNID)
roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String())
if rErr != nil {
return rErr
}
if roomInfo == nil {
return nil
}
key, sErr = d.UserRoomKeyTable.SelectUserRoomPrivateKey(ctx, txn, stateKeyNID, roomInfo.RoomNID)
if !errors.Is(sErr, sql.ErrNoRows) {
return sErr
}

View file

@ -7,7 +7,8 @@ import (
"time"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert"
ed255192 "golang.org/x/crypto/ed25519"
@ -26,32 +27,42 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat
connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
dbOpts := &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}
db, err := sqlutil.Open(dbOpts, sqlutil.NewExclusiveWriter())
writer := sqlutil.NewExclusiveWriter()
db, err := sqlutil.Open(dbOpts, writer)
assert.NoError(t, err)
var membershipTable tables.Membership
var stateKeyTable tables.EventStateKeys
var userRoomKeys tables.UserRoomKeys
var roomsTable tables.Rooms
switch dbType {
case test.DBTypePostgres:
err = postgres.CreateRoomsTable(db)
assert.NoError(t, err)
err = postgres.CreateEventStateKeysTable(db)
assert.NoError(t, err)
err = postgres.CreateMembershipTable(db)
assert.NoError(t, err)
err = postgres.CreateUserRoomKeysTable(db)
assert.NoError(t, err)
roomsTable, err = postgres.PrepareRoomsTable(db)
assert.NoError(t, err)
membershipTable, err = postgres.PrepareMembershipTable(db)
assert.NoError(t, err)
stateKeyTable, err = postgres.PrepareEventStateKeysTable(db)
assert.NoError(t, err)
userRoomKeys, err = postgres.PrepareUserRoomKeysTable(db)
case test.DBTypeSQLite:
err = sqlite3.CreateRoomsTable(db)
assert.NoError(t, err)
err = sqlite3.CreateEventStateKeysTable(db)
assert.NoError(t, err)
err = sqlite3.CreateMembershipTable(db)
assert.NoError(t, err)
err = sqlite3.CreateUserRoomKeysTable(db)
assert.NoError(t, err)
roomsTable, err = sqlite3.PrepareRoomsTable(db)
assert.NoError(t, err)
membershipTable, err = sqlite3.PrepareMembershipTable(db)
assert.NoError(t, err)
stateKeyTable, err = sqlite3.PrepareEventStateKeysTable(db)
@ -62,14 +73,15 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat
cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false)
evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache}
evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache, Writer: writer}
return &shared.Database{
DB: db,
EventDatabase: evDb,
MembershipTable: membershipTable,
UserRoomKeyTable: userRoomKeys,
Writer: sqlutil.NewExclusiveWriter(),
RoomsTable: roomsTable,
Writer: writer,
Cache: cache,
}, func() {
clearDB()
@ -113,36 +125,47 @@ func Test_GetLeftUsers(t *testing.T) {
func TestUserRoomKeys(t *testing.T) {
ctx := context.Background()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
userID, err := spec.NewUserID(alice.ID, true)
assert.NoError(t, err)
roomID, err := spec.NewRoomID(room.ID)
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRoomserverDatabase(t, dbType)
defer close()
roomNID := types.RoomNID(1)
// create a room NID so we can query the room
_, err = db.RoomsTable.InsertRoomNID(ctx, nil, roomID.String(), gomatrixserverlib.RoomVersionV10)
assert.NoError(t, err)
doesNotExist, err := spec.NewRoomID("!doesnotexist:localhost")
assert.NoError(t, err)
_, err = db.RoomsTable.InsertRoomNID(ctx, nil, doesNotExist.String(), gomatrixserverlib.RoomVersionV10)
assert.NoError(t, err)
_, key, err := ed25519.GenerateKey(nil)
assert.NoError(t, err)
// insert dummy event state keys
dummy := test.NewUser(t)
userNID, err := db.GetOrCreateEventStateKeyNID(ctx, &dummy.ID)
assert.NoError(t, err)
gotKey, err := db.InsertUserRoomPrivateKey(ctx, userNID, roomNID, key)
gotKey, err := db.InsertUserRoomPrivateKey(ctx, *userID, *roomID, key)
assert.NoError(t, err)
assert.Equal(t, gotKey, key)
// again, this shouldn't result in an error, but return the existing key
_, key2, err := ed25519.GenerateKey(nil)
assert.NoError(t, err)
gotKey, err = db.InsertUserRoomPrivateKey(context.Background(), userNID, roomNID, key2)
gotKey, err = db.InsertUserRoomPrivateKey(ctx, *userID, *roomID, key2)
assert.NoError(t, err)
assert.Equal(t, gotKey, key)
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), userNID, roomNID)
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID)
assert.NoError(t, err)
assert.Equal(t, key, gotKey)
// Key doesn't exist, we shouldn't get anything back
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), userNID, 2)
assert.NoError(t, err)
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist)
assert.NoError(t, err)
assert.Nil(t, gotKey)
@ -154,9 +177,16 @@ func TestUserRoomKeys(t *testing.T) {
var gotPublicKey, key4 ed255192.PublicKey
key4, _, err = ed25519.GenerateKey(nil)
assert.NoError(t, err)
gotPublicKey, err = db.InsertUserRoomPublicKey(context.Background(), userNID, 2, key4)
gotPublicKey, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *doesNotExist, key4)
assert.NoError(t, err)
assert.Equal(t, key4, gotPublicKey)
// test invalid room
reallyDoesNotExist, err := spec.NewRoomID("!reallydoesnotexist:localhost")
assert.NoError(t, err)
_, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *reallyDoesNotExist, key4)
assert.Error(t, err)
_, err = db.InsertUserRoomPrivateKey(context.Background(), *userID, *reallyDoesNotExist, key)
assert.Error(t, err)
})
}