From a2c4974d9a47f44bb7f952d6eff0a7909b0e0726 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Tue, 11 Jul 2023 16:19:49 +0200 Subject: [PATCH] Optimise getting local members and membership counts --- roomserver/api/api.go | 1 + roomserver/internal/query/query.go | 14 ++++++++++ userapi/consumers/roomserver.go | 25 ++++++----------- userapi/consumers/roomserver_test.go | 42 ++++++++++++++++++++++++++++ 4 files changed, 66 insertions(+), 16 deletions(-) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index ab56529c5..c29406a1a 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -227,6 +227,7 @@ type UserRoomserverAPI interface { QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error) + JoinedUserCount(ctx context.Context, roomID string) (int, error) } type FederationRoomserverAPI interface { diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 626d3c13e..39e3bd0ec 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -974,6 +974,20 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse return joinedUsers, nil } +func (r *Queryer) JoinedUserCount(ctx context.Context, roomID string) (int, error) { + info, err := r.DB.RoomInfo(ctx, roomID) + if err != nil { + return 0, err + } + if info == nil { + return 0, nil + } + + // TODO: this can be further optimised by just using a SELECT COUNT query + nids, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) + return len(nids), err +} + // nolint:gocyclo func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { // Look up if we know anything about the room. If it doesn't exist diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 9a9a407ce..1d0956811 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -408,6 +408,7 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s req := &rsapi.QueryMembershipsForRoomRequest{ RoomID: roomID, JoinedOnly: true, + LocalOnly: true, } var res rsapi.QueryMembershipsForRoomResponse @@ -417,6 +418,11 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s return nil, 0, err } + totalCount, err := s.rsAPI.JoinedUserCount(ctx, roomID) + if err != nil { + return nil, 0, err + } + var members []*localMembership for _, event := range res.JoinEvents { // Filter out invalid join events @@ -426,31 +432,18 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s if *event.StateKey == "" { continue } - _, serverName, err := gomatrixserverlib.SplitID('@', *event.StateKey) - if err != nil { - log.WithError(err).Error("failed to get servername from statekey") - continue - } - // Only get memberships for our server - if serverName != s.serverName { - continue - } + // We're going to trust the Query from above to really just return + // local users member, err := newLocalMembership(&event) if err != nil { log.WithError(err).Errorf("Parsing MemberContent") continue } - if member.Membership != spec.Join { - continue - } - if member.Domain != s.cfg.Matrix.ServerName { - continue - } members = append(members, member) } - return members, len(res.JoinEvents), nil + return members, totalCount, nil } // roomName returns the name in the event (if type==m.room.name), or diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 4dc81e74a..c5e68bc05 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -2,16 +2,22 @@ package consumers import ( "context" + "crypto/ed25519" "reflect" "sync" "testing" "time" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" + "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/internal/pushrules" rsapi "github.com/matrix-org/dendrite/roomserver/api" @@ -139,6 +145,42 @@ func Test_evaluatePushRules(t *testing.T) { }) } +func TestLocalRoomMembers(t *testing.T) { + alice := test.NewUser(t) + _, sk, err := ed25519.GenerateKey(nil) + assert.NoError(t, err) + bob := test.NewUser(t, test.WithSigningServer("notlocalhost", "ed25519:abc", sk)) + charlie := test.NewUser(t, test.WithSigningServer("notlocalhost", "ed25519:abc", sk)) + + room := test.NewRoom(t, alice) + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]string{"membership": spec.Join}, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, charlie, spec.MRoomMember, map[string]string{"membership": spec.Join}, test.WithStateKey(charlie.ID)) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + natsInstance := &jetstream.NATSInstance{} + caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + db, err := storage.NewUserDatabase(processCtx.Context(), cm, &cfg.UserAPI.AccountDatabase, cfg.Global.ServerName, bcrypt.MinCost, 1000, 1000, "") + assert.NoError(t, err) + + err = rsapi.SendEvents(processCtx.Context(), rsAPI, rsapi.KindNew, room.Events(), "", "test", "test", nil, false) + assert.NoError(t, err) + + consumer := OutputRoomEventConsumer{db: db, rsAPI: rsAPI, serverName: "test", cfg: &cfg.UserAPI} + members, count, err := consumer.localRoomMembers(processCtx.Context(), room.ID) + assert.NoError(t, err) + assert.Equal(t, 3, count) + expectedLocalMember := &localMembership{UserID: alice.ID, Localpart: alice.Localpart, Domain: "test", MemberContent: gomatrixserverlib.MemberContent{Membership: spec.Join}} + assert.Equal(t, expectedLocalMember, members[0]) + }) + +} + func TestMessageStats(t *testing.T) { type args struct { eventType string