Optimise getting local members and membership counts (#3150)

The previous version was getting **ALL** membership events (as
`ClientEvents`, so going through `NewEventFromTrustedJSONWithID`) for a
given room.
Now we are querying only locally joined users as `ClientEvents`, which
should **significantly** reduce allocations.

Take for example a large room with 2k membership events, but only 1
local user - avoiding 1999 `NewEventFromTrustedJSONWithID` calls just to
calculate the `roomSize` which we can also query by other means.

This is also getting called for every `OutputRoomEvent` in the userAPI.

Benchmark with 1 local user and 100 remote users.
```
pkg: github.com/matrix-org/dendrite/userapi/consumers
cpu: 12th Gen Intel(R) Core(TM) i5-12500H
                    │   old.txt   │               new.txt               │
                    │   sec/op    │   sec/op     vs base                │
LocalRoomMembers-16   375.9µ ± 7%   327.6µ ± 6%  -12.85% (p=0.000 n=10)

                    │    old.txt    │               new.txt                │
                    │     B/op      │     B/op      vs base                │
LocalRoomMembers-16   79.426Ki ± 0%   8.507Ki ± 0%  -89.29% (p=0.000 n=10)

                    │   old.txt   │              new.txt               │
                    │  allocs/op  │ allocs/op   vs base                │
LocalRoomMembers-16   1015.0 ± 0%   277.0 ± 0%  -72.71% (p=0.000 n=10)
```
This commit is contained in:
Till 2023-07-13 14:19:08 +02:00 committed by GitHub
parent f12982472c
commit 5267cc0f54
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 109 additions and 19 deletions

View file

@ -227,6 +227,7 @@ type UserRoomserverAPI interface {
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error) PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
JoinedUserCount(ctx context.Context, roomID string) (int, error)
} }
type FederationRoomserverAPI interface { type FederationRoomserverAPI interface {

View file

@ -974,6 +974,20 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse
return joinedUsers, nil return joinedUsers, nil
} }
func (r *Queryer) JoinedUserCount(ctx context.Context, roomID string) (int, error) {
info, err := r.DB.RoomInfo(ctx, roomID)
if err != nil {
return 0, err
}
if info == nil {
return 0, nil
}
// TODO: this can be further optimised by just using a SELECT COUNT query
nids, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
return len(nids), err
}
// nolint:gocyclo // nolint:gocyclo
func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
// Look up if we know anything about the room. If it doesn't exist // Look up if we know anything about the room. If it doesn't exist

View file

@ -405,18 +405,25 @@ func newLocalMembership(event *synctypes.ClientEvent) (*localMembership, error)
// localRoomMembers fetches the current local members of a room, and // localRoomMembers fetches the current local members of a room, and
// the total number of members. // the total number of members.
func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) {
// Get only locally joined users to avoid unmarshalling and caching
// membership events we only use to calculate the room size.
req := &rsapi.QueryMembershipsForRoomRequest{ req := &rsapi.QueryMembershipsForRoomRequest{
RoomID: roomID, RoomID: roomID,
JoinedOnly: true, JoinedOnly: true,
LocalOnly: true,
} }
var res rsapi.QueryMembershipsForRoomResponse var res rsapi.QueryMembershipsForRoomResponse
// XXX: This could potentially race if the state for the event is not known yet
// e.g. the event came over federation but we do not have the full state persisted.
if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil { if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil {
return nil, 0, err return nil, 0, err
} }
// Since we only queried locally joined users above,
// we also need to ask the roomserver about the joined user count.
totalCount, err := s.rsAPI.JoinedUserCount(ctx, roomID)
if err != nil {
return nil, 0, err
}
var members []*localMembership var members []*localMembership
for _, event := range res.JoinEvents { for _, event := range res.JoinEvents {
// Filter out invalid join events // Filter out invalid join events
@ -426,31 +433,18 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s
if *event.StateKey == "" { if *event.StateKey == "" {
continue continue
} }
_, serverName, err := gomatrixserverlib.SplitID('@', *event.StateKey) // We're going to trust the Query from above to really just return
if err != nil { // local users
log.WithError(err).Error("failed to get servername from statekey")
continue
}
// Only get memberships for our server
if serverName != s.serverName {
continue
}
member, err := newLocalMembership(&event) member, err := newLocalMembership(&event)
if err != nil { if err != nil {
log.WithError(err).Errorf("Parsing MemberContent") log.WithError(err).Errorf("Parsing MemberContent")
continue continue
} }
if member.Membership != spec.Join {
continue
}
if member.Domain != s.cfg.Matrix.ServerName {
continue
}
members = append(members, member) members = append(members, member)
} }
return members, len(res.JoinEvents), nil return members, totalCount, nil
} }
// roomName returns the name in the event (if type==m.room.name), or // roomName returns the name in the event (if type==m.room.name), or

View file

@ -2,16 +2,22 @@ package consumers
import ( import (
"context" "context"
"crypto/ed25519"
"reflect" "reflect"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
rsapi "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api"
@ -139,6 +145,42 @@ func Test_evaluatePushRules(t *testing.T) {
}) })
} }
func TestLocalRoomMembers(t *testing.T) {
alice := test.NewUser(t)
_, sk, err := ed25519.GenerateKey(nil)
assert.NoError(t, err)
bob := test.NewUser(t, test.WithSigningServer("notlocalhost", "ed25519:abc", sk))
charlie := test.NewUser(t, test.WithSigningServer("notlocalhost", "ed25519:abc", sk))
room := test.NewRoom(t, alice)
room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]string{"membership": spec.Join}, test.WithStateKey(bob.ID))
room.CreateAndInsert(t, charlie, spec.MRoomMember, map[string]string{"membership": spec.Join}, test.WithStateKey(charlie.ID))
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
defer close()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
natsInstance := &jetstream.NATSInstance{}
caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
db, err := storage.NewUserDatabase(processCtx.Context(), cm, &cfg.UserAPI.AccountDatabase, cfg.Global.ServerName, bcrypt.MinCost, 1000, 1000, "")
assert.NoError(t, err)
err = rsapi.SendEvents(processCtx.Context(), rsAPI, rsapi.KindNew, room.Events(), "", "test", "test", nil, false)
assert.NoError(t, err)
consumer := OutputRoomEventConsumer{db: db, rsAPI: rsAPI, serverName: "test", cfg: &cfg.UserAPI}
members, count, err := consumer.localRoomMembers(processCtx.Context(), room.ID)
assert.NoError(t, err)
assert.Equal(t, 3, count)
expectedLocalMember := &localMembership{UserID: alice.ID, Localpart: alice.Localpart, Domain: "test", MemberContent: gomatrixserverlib.MemberContent{Membership: spec.Join}}
assert.Equal(t, expectedLocalMember, members[0])
})
}
func TestMessageStats(t *testing.T) { func TestMessageStats(t *testing.T) {
type args struct { type args struct {
eventType string eventType string
@ -257,3 +299,42 @@ func TestMessageStats(t *testing.T) {
} }
}) })
} }
func BenchmarkLocalRoomMembers(b *testing.B) {
t := &testing.T{}
cfg, processCtx, close := testrig.CreateConfig(t, test.DBTypePostgres)
defer close()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
natsInstance := &jetstream.NATSInstance{}
caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
db, err := storage.NewUserDatabase(processCtx.Context(), cm, &cfg.UserAPI.AccountDatabase, cfg.Global.ServerName, bcrypt.MinCost, 1000, 1000, "")
assert.NoError(b, err)
consumer := OutputRoomEventConsumer{db: db, rsAPI: rsAPI, serverName: "test", cfg: &cfg.UserAPI}
_, sk, err := ed25519.GenerateKey(nil)
assert.NoError(b, err)
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
for i := 0; i < 100; i++ {
user := test.NewUser(t, test.WithSigningServer("notlocalhost", "ed25519:abc", sk))
room.CreateAndInsert(t, user, spec.MRoomMember, map[string]string{"membership": spec.Join}, test.WithStateKey(user.ID))
}
err = rsapi.SendEvents(processCtx.Context(), rsAPI, rsapi.KindNew, room.Events(), "", "test", "test", nil, false)
assert.NoError(b, err)
expectedLocalMember := &localMembership{UserID: alice.ID, Localpart: alice.Localpart, Domain: "test", MemberContent: gomatrixserverlib.MemberContent{Membership: spec.Join}}
b.ResetTimer()
for i := 0; i < b.N; i++ {
members, count, err := consumer.localRoomMembers(processCtx.Context(), room.ID)
assert.NoError(b, err)
assert.Equal(b, 101, count)
assert.Equal(b, expectedLocalMember, members[0])
}
}