From c11e7061f3cac4af7f9f15f1ddba49ee2fa32b17 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Mon, 12 Jun 2023 17:12:55 +0100 Subject: [PATCH] Convert between senderID/userID using the db/cache only when necessary --- clientapi/routing/joinroom_test.go | 5 ++++ .../internal/perform/perform_create_room.go | 13 +++++----- roomserver/internal/query/query.go | 19 +++++++++++++- roomserver/storage/interface.go | 1 + roomserver/storage/shared/storage.go | 25 +++++++++++++++++++ 5 files changed, 55 insertions(+), 8 deletions(-) diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go index 0ddff8a95..e4c52786f 100644 --- a/clientapi/routing/joinroom_test.go +++ b/clientapi/routing/joinroom_test.go @@ -9,6 +9,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" @@ -33,6 +34,10 @@ func TestJoinRoomByIDOrAlias(t *testing.T) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + _, openErr := storage.Open(processCtx.Context(), cm, &cfg.RoomServer.Database, caches) + if openErr != nil { + t.Fatal(openErr) + } natsInstance := jetstream.NATSInstance{} rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 64c1d6a84..59b1efa61 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -63,13 +63,12 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } } - senderID, err := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), userID) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user") - return "", &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } + var senderID spec.SenderID + if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + // TODO: pseudoIDs - generate senderID kere! + senderID = "pseudo_id.sender.key" + } else { + senderID = spec.SenderID(userID.String()) } createContent["creator"] = senderID createContent["room_version"] = createRequest.RoomVersion diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 439879354..de8f597c7 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -990,9 +990,26 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro } func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { - return r.DB.GetSenderIDForUser(ctx, roomID, userID) + roomInfo, err := r.DB.GetOrCreateRoomInfoFromID(ctx, roomID) + if err != nil { + return "", err + } + if roomInfo == nil { + return "", fmt.Errorf("No room info found for %s", roomID) + } + + switch roomInfo.RoomVersion { + case gomatrixserverlib.RoomVersionPseudoIDs: + return r.DB.GetSenderIDForUser(ctx, roomID, userID) + default: + return spec.SenderID(userID.String()), nil + } } func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + userID, err := spec.NewUserID(string(senderID), true) + if err == nil { + return userID, nil + } return r.DB.GetUserIDForSender(ctx, roomID, senderID) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 6d7d933fa..a29bf8cc8 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -189,6 +189,7 @@ type Database interface { ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string, ) (map[string]*types.HeaderedEvent, error) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (*types.RoomInfo, error) + GetOrCreateRoomInfoFromID(ctx context.Context, roomID string) (*types.RoomInfo, error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) MaybeRedactEvent( diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 85a1ba7a1..f9c9ceabb 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -701,6 +701,31 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserver }, err } +// GetOrCreateRoomInfo gets or creates a new RoomInfo, which is only safe to use with functions only needing a roomVersion or roomNID. +func (d *Database) GetOrCreateRoomInfoFromID(ctx context.Context, roomID string) (roomInfo *types.RoomInfo, err error) { + roomNID, nidOK := d.Cache.GetRoomServerRoomNID(roomID) + cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(roomID) + // if we found both, the roomNID and version in our cache, no need to query the database + if nidOK && versionOK { + return &types.RoomInfo{ + RoomNID: roomNID, + RoomVersion: cachedRoomVersion, + }, nil + } + + roomInfo, err = d.RoomInfo(ctx, roomID) + if err != nil { + return nil, err + } + if roomInfo == nil { + return nil, fmt.Errorf("Failed to find room info for %s", roomID) + } + + d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID) + d.Cache.StoreRoomVersion(roomID, roomInfo.RoomVersion) + return roomInfo, nil +} + func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, eventType); err != nil {