diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index e915767cb..d020450ff 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -190,9 +190,9 @@ type Database interface { ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) - InsertUserRoomPrivateKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) - SelectUserRoomPrivateKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error) - InsertUserRoomPublicKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) + InsertUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) + SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) + InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys [][]byte) (map[string]string, error) } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 108bfd41d..62914fad7 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -9,6 +9,7 @@ import ( "fmt" "sort" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" @@ -1595,10 +1596,25 @@ func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventS // InsertUserRoomPrivateKey inserts a new user room key for the given user and room. // Returns the newly inserted private key or an existing private key. If there is // an error talking to the database, returns that error. -func (d *Database) InsertUserRoomPrivateKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { +func (d *Database) InsertUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return eventutil.ErrRoomNoExists{} + } + var iErr error - result, iErr = d.UserRoomKeyTable.InsertUserRoomPrivateKey(ctx, txn, userNID, roomNID, key) + result, iErr = d.UserRoomKeyTable.InsertUserRoomPrivateKey(ctx, txn, stateKeyNID, roomInfo.RoomNID, key) return iErr }) return result, err @@ -1607,10 +1623,25 @@ func (d *Database) InsertUserRoomPrivateKey(ctx context.Context, userNID types.E // InsertUserRoomPublicKey inserts a new user room key for the given user and room. // Returns the newly inserted public key or an existing public key. If there is // an error talking to the database, returns that error. -func (d *Database) InsertUserRoomPublicKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) { +func (d *Database) InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return eventutil.ErrRoomNoExists{} + } + var iErr error - result, iErr = d.UserRoomKeyTable.InsertUserRoomPublicKey(ctx, txn, userNID, roomNID, key) + result, iErr = d.UserRoomKeyTable.InsertUserRoomPublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID, key) return iErr }) return result, err @@ -1619,10 +1650,24 @@ func (d *Database) InsertUserRoomPublicKey(ctx context.Context, userNID types.Ev // SelectUserRoomPrivateKey queries the users room private key. // If no key exists, returns no key and no error. Otherwise returns // the key and a database error, if any. -func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error) { +func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - var sErr error - key, sErr = d.UserRoomKeyTable.SelectUserRoomPrivateKey(ctx, txn, userNID, roomNID) + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return nil + } + + key, sErr = d.UserRoomKeyTable.SelectUserRoomPrivateKey(ctx, txn, stateKeyNID, roomInfo.RoomNID) if !errors.Is(sErr, sql.ErrNoRows) { return sErr } diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index 8c60d6821..ce1c46c25 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -7,7 +7,8 @@ import ( "time" "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ed255192 "golang.org/x/crypto/ed25519" @@ -26,32 +27,42 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat connStr, clearDB := test.PrepareDBConnectionString(t, dbType) dbOpts := &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)} - db, err := sqlutil.Open(dbOpts, sqlutil.NewExclusiveWriter()) + writer := sqlutil.NewExclusiveWriter() + db, err := sqlutil.Open(dbOpts, writer) assert.NoError(t, err) var membershipTable tables.Membership var stateKeyTable tables.EventStateKeys var userRoomKeys tables.UserRoomKeys + var roomsTable tables.Rooms switch dbType { case test.DBTypePostgres: + err = postgres.CreateRoomsTable(db) + assert.NoError(t, err) err = postgres.CreateEventStateKeysTable(db) assert.NoError(t, err) err = postgres.CreateMembershipTable(db) assert.NoError(t, err) err = postgres.CreateUserRoomKeysTable(db) assert.NoError(t, err) + roomsTable, err = postgres.PrepareRoomsTable(db) + assert.NoError(t, err) membershipTable, err = postgres.PrepareMembershipTable(db) assert.NoError(t, err) stateKeyTable, err = postgres.PrepareEventStateKeysTable(db) assert.NoError(t, err) userRoomKeys, err = postgres.PrepareUserRoomKeysTable(db) case test.DBTypeSQLite: + err = sqlite3.CreateRoomsTable(db) + assert.NoError(t, err) err = sqlite3.CreateEventStateKeysTable(db) assert.NoError(t, err) err = sqlite3.CreateMembershipTable(db) assert.NoError(t, err) err = sqlite3.CreateUserRoomKeysTable(db) assert.NoError(t, err) + roomsTable, err = sqlite3.PrepareRoomsTable(db) + assert.NoError(t, err) membershipTable, err = sqlite3.PrepareMembershipTable(db) assert.NoError(t, err) stateKeyTable, err = sqlite3.PrepareEventStateKeysTable(db) @@ -62,14 +73,15 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false) - evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache} + evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache, Writer: writer} return &shared.Database{ DB: db, EventDatabase: evDb, MembershipTable: membershipTable, UserRoomKeyTable: userRoomKeys, - Writer: sqlutil.NewExclusiveWriter(), + RoomsTable: roomsTable, + Writer: writer, Cache: cache, }, func() { clearDB() @@ -113,36 +125,47 @@ func Test_GetLeftUsers(t *testing.T) { func TestUserRoomKeys(t *testing.T) { ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + userID, err := spec.NewUserID(alice.ID, true) + assert.NoError(t, err) + roomID, err := spec.NewRoomID(room.ID) + assert.NoError(t, err) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateRoomserverDatabase(t, dbType) defer close() - roomNID := types.RoomNID(1) + // create a room NID so we can query the room + _, err = db.RoomsTable.InsertRoomNID(ctx, nil, roomID.String(), gomatrixserverlib.RoomVersionV10) + assert.NoError(t, err) + doesNotExist, err := spec.NewRoomID("!doesnotexist:localhost") + assert.NoError(t, err) + _, err = db.RoomsTable.InsertRoomNID(ctx, nil, doesNotExist.String(), gomatrixserverlib.RoomVersionV10) + assert.NoError(t, err) + _, key, err := ed25519.GenerateKey(nil) assert.NoError(t, err) - // insert dummy event state keys - dummy := test.NewUser(t) - userNID, err := db.GetOrCreateEventStateKeyNID(ctx, &dummy.ID) - assert.NoError(t, err) - - gotKey, err := db.InsertUserRoomPrivateKey(ctx, userNID, roomNID, key) + gotKey, err := db.InsertUserRoomPrivateKey(ctx, *userID, *roomID, key) assert.NoError(t, err) assert.Equal(t, gotKey, key) // again, this shouldn't result in an error, but return the existing key _, key2, err := ed25519.GenerateKey(nil) assert.NoError(t, err) - gotKey, err = db.InsertUserRoomPrivateKey(context.Background(), userNID, roomNID, key2) + gotKey, err = db.InsertUserRoomPrivateKey(ctx, *userID, *roomID, key2) assert.NoError(t, err) assert.Equal(t, gotKey, key) - gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), userNID, roomNID) + gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID) assert.NoError(t, err) assert.Equal(t, key, gotKey) // Key doesn't exist, we shouldn't get anything back - gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), userNID, 2) + assert.NoError(t, err) + gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist) assert.NoError(t, err) assert.Nil(t, gotKey) @@ -154,9 +177,16 @@ func TestUserRoomKeys(t *testing.T) { var gotPublicKey, key4 ed255192.PublicKey key4, _, err = ed25519.GenerateKey(nil) assert.NoError(t, err) - gotPublicKey, err = db.InsertUserRoomPublicKey(context.Background(), userNID, 2, key4) + gotPublicKey, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *doesNotExist, key4) assert.NoError(t, err) assert.Equal(t, key4, gotPublicKey) + // test invalid room + reallyDoesNotExist, err := spec.NewRoomID("!reallydoesnotexist:localhost") + assert.NoError(t, err) + _, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *reallyDoesNotExist, key4) + assert.Error(t, err) + _, err = db.InsertUserRoomPrivateKey(context.Background(), *userID, *reallyDoesNotExist, key) + assert.Error(t, err) }) }