diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 4ad308a3e..8e26854a1 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -2,6 +2,7 @@ package api import ( "context" + "crypto/ed25519" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" @@ -66,6 +67,9 @@ type RoomserverInternalAPI interface { req *QueryAuthChainRequest, res *QueryAuthChainResponse, ) error + + // GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. + GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) } type InputRoomEventsAPI interface { diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 1e5aba57c..7943ae5c0 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -2,6 +2,7 @@ package internal import ( "context" + "crypto/ed25519" "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" @@ -271,3 +272,23 @@ func (r *RoomserverInternalAPI) PerformForget( ) error { return r.Forgetter.PerformForget(ctx, req, resp) } + +// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. +func (r *RoomserverInternalAPI) GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) { + key, err := r.DB.SelectUserRoomPrivateKey(ctx, userID, roomID) + if err != nil { + return nil, err + } + // no key found, create one + if len(key) == 0 { + _, key, err = ed25519.GenerateKey(nil) + if err != nil { + return nil, err + } + key, err = r.DB.InsertUserRoomPrivatePublicKey(ctx, userID, roomID, key) + if err != nil { + return nil, err + } + } + return key, nil +} diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index b81627b04..a79b369a4 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -353,7 +353,30 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo SendAsServer: api.DoNotSendToOtherServers, }) } - if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs, false); err != nil { + + // first send the `m.room.create` event, so we have a roomNID + if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[:1], false); err != nil { + util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + // create user room key if needed + if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + _, err = c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + } + + // send the remaining events + if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[1:], false); err != nil { util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") return "", &util.JSONResponse{ Code: http.StatusInternalServerError, diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 8f188994f..babd5f812 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -188,6 +188,14 @@ func (r *Inviter) PerformInvite( inviteEvent = event } + // if we invited a local user, we can also create a user room key, if it doesn't exist yet. + if isTargetLocal && event.Version() == gomatrixserverlib.RoomVersionPseudoIDs { + _, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *invitedUser, *validRoomID) + if err != nil { + return fmt.Errorf("failed to get user room private key: %w", err) + } + } + // Send the invite event to the roomserver input stream. This will // notify existing users in the room about the invite, update the // membership table and ensure that the event is ready and available diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index ffabeabab..11d2bf395 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -308,6 +308,15 @@ func (r *Joiner) performJoinRoomByID( switch err.(type) { case nil: + // create user room key if needed + if buildRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + _, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID) + if err != nil { + logrus.WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") + return "", "", fmt.Errorf("failed to get user room private key: %w", err) + } + } + // The room join is local. Send the new join event into the // roomserver. First of all check that the user isn't already // a member of the room. This is best-effort (as in we won't diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 447b90552..7156c11cc 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -31,6 +31,7 @@ type Database interface { UserRoomKeys // Do we support processing input events for more than one room at a time? SupportsConcurrentRoomInputs() bool + AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) // RoomInfo returns room information for the given room ID, or nil if there is no room. RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) @@ -213,6 +214,7 @@ type UserRoomKeys interface { type RoomDatabase interface { EventDatabase UserRoomKeys + AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) // RoomInfo returns room information for the given room ID, or nil if there is no room. RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 1cd0f23a2..61a3520a4 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -662,6 +662,26 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID) } +func (d *Database) AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) { + // This should already be checked, let's check it anyway. + _, err = gomatrixserverlib.GetRoomVersion(roomVersion) + if err != nil { + return 0, err + } + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomNID, err = d.assignRoomNID(ctx, txn, roomID.String(), roomVersion) + if err != nil { + return err + } + return nil + }) + if err != nil { + return 0, err + } + // Not setting caches, as assignRoomNID already does this + return roomNID, err +} + // GetOrCreateRoomInfo gets or creates a new RoomInfo, which is only safe to use with functions only needing a roomVersion or roomNID. func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (roomInfo *types.RoomInfo, err error) { // Get the default room version. If the client doesn't supply a room_version @@ -1692,7 +1712,7 @@ func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.Use return rErr } if roomInfo == nil { - return nil + return eventutil.ErrRoomNoExists{} } key, sErr = d.UserRoomKeyTable.SelectUserRoomPrivateKey(ctx, txn, stateKeyNID, roomInfo.RoomNID) diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index 8f83af3bb..c7b915c7d 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" @@ -204,3 +205,24 @@ func TestUserRoomKeys(t *testing.T) { assert.Error(t, err) }) } + +func TestAssignRoomNID(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + roomID, err := spec.NewRoomID(room.ID) + assert.NoError(t, err) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRoomserverDatabase(t, dbType) + defer close() + + nid, err := db.AssignRoomNID(ctx, *roomID, room.Version) + assert.NoError(t, err) + assert.Greater(t, nid, types.EventNID(0)) + + _, err = db.AssignRoomNID(ctx, spec.RoomID{}, "notaroomversion") + assert.Error(t, err) + }) +}