Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/memberships
This commit is contained in:
commit
16075ce657
2
go.mod
2
go.mod
|
@ -26,7 +26,7 @@ require (
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
|
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
|
||||||
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
|
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
|
||||||
github.com/mattn/go-sqlite3 v1.14.17
|
github.com/mattn/go-sqlite3 v1.14.17
|
||||||
github.com/nats-io/nats-server/v2 v2.9.15
|
github.com/nats-io/nats-server/v2 v2.9.19
|
||||||
github.com/nats-io/nats.go v1.27.0
|
github.com/nats-io/nats.go v1.27.0
|
||||||
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
|
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -243,8 +243,8 @@ github.com/mschoch/smat v0.2.0 h1:8imxQsjDm8yFEAVBe7azKmKSgzSkZXDuKkSq9374khM=
|
||||||
github.com/mschoch/smat v0.2.0/go.mod h1:kc9mz7DoBKqDyiRL7VZN8KvXQMWeTaVnttLRXOlotKw=
|
github.com/mschoch/smat v0.2.0/go.mod h1:kc9mz7DoBKqDyiRL7VZN8KvXQMWeTaVnttLRXOlotKw=
|
||||||
github.com/nats-io/jwt/v2 v2.4.1 h1:Y35W1dgbbz2SQUYDPCaclXcuqleVmpbRa7646Jf2EX4=
|
github.com/nats-io/jwt/v2 v2.4.1 h1:Y35W1dgbbz2SQUYDPCaclXcuqleVmpbRa7646Jf2EX4=
|
||||||
github.com/nats-io/jwt/v2 v2.4.1/go.mod h1:24BeQtRwxRV8ruvC4CojXlx/WQ/VjuwlYiH+vu/+ibI=
|
github.com/nats-io/jwt/v2 v2.4.1/go.mod h1:24BeQtRwxRV8ruvC4CojXlx/WQ/VjuwlYiH+vu/+ibI=
|
||||||
github.com/nats-io/nats-server/v2 v2.9.15 h1:MuwEJheIwpvFgqvbs20W8Ish2azcygjf4Z0liVu2I4c=
|
github.com/nats-io/nats-server/v2 v2.9.19 h1:OF9jSKZGo425C/FcVVIvNgpd36CUe7aVTTXEZRJk6kA=
|
||||||
github.com/nats-io/nats-server/v2 v2.9.15/go.mod h1:QlCTy115fqpx4KSOPFIxSV7DdI6OxtZsGOL1JLdeRlE=
|
github.com/nats-io/nats-server/v2 v2.9.19/go.mod h1:aTb/xtLCGKhfTFLxP591CMWfkdgBmcUUSkiSOe5A3gw=
|
||||||
github.com/nats-io/nats.go v1.27.0 h1:3o9fsPhmoKm+yK7rekH2GtWoE+D9jFbw8N3/ayI1C00=
|
github.com/nats-io/nats.go v1.27.0 h1:3o9fsPhmoKm+yK7rekH2GtWoE+D9jFbw8N3/ayI1C00=
|
||||||
github.com/nats-io/nats.go v1.27.0/go.mod h1:XpbWUlOElGwTYbMR7imivs7jJj9GtK7ypv321Wp6pjc=
|
github.com/nats-io/nats.go v1.27.0/go.mod h1:XpbWUlOElGwTYbMR7imivs7jJj9GtK7ypv321Wp6pjc=
|
||||||
github.com/nats-io/nkeys v0.4.4 h1:xvBJ8d69TznjcQl9t6//Q5xXuVhyYiSos6RPtvQNTwA=
|
github.com/nats-io/nkeys v0.4.4 h1:xvBJ8d69TznjcQl9t6//Q5xXuVhyYiSos6RPtvQNTwA=
|
||||||
|
|
|
@ -227,6 +227,7 @@ type UserRoomserverAPI interface {
|
||||||
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
|
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
|
||||||
PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
|
PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
|
||||||
PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
|
PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
|
||||||
|
JoinedUserCount(ctx context.Context, roomID string) (int, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type FederationRoomserverAPI interface {
|
type FederationRoomserverAPI interface {
|
||||||
|
|
|
@ -974,6 +974,20 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse
|
||||||
return joinedUsers, nil
|
return joinedUsers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) JoinedUserCount(ctx context.Context, roomID string) (int, error) {
|
||||||
|
info, err := r.DB.RoomInfo(ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if info == nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: this can be further optimised by just using a SELECT COUNT query
|
||||||
|
nids, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
|
||||||
|
return len(nids), err
|
||||||
|
}
|
||||||
|
|
||||||
// nolint:gocyclo
|
// nolint:gocyclo
|
||||||
func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
|
func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
|
||||||
// Look up if we know anything about the room. If it doesn't exist
|
// Look up if we know anything about the room. If it doesn't exist
|
||||||
|
|
|
@ -53,6 +53,7 @@ type messagesReq struct {
|
||||||
wasToProvided bool
|
wasToProvided bool
|
||||||
backwardOrdering bool
|
backwardOrdering bool
|
||||||
filter *synctypes.RoomEventFilter
|
filter *synctypes.RoomEventFilter
|
||||||
|
didBackfill bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type messagesResp struct {
|
type messagesResp struct {
|
||||||
|
@ -251,18 +252,19 @@ func OnIncomingMessagesRequest(
|
||||||
}
|
}
|
||||||
|
|
||||||
// If start and end are equal, we either reached the beginning or something else
|
// If start and end are equal, we either reached the beginning or something else
|
||||||
// is wrong. To avoid endless loops from clients, set end to 0 an empty string
|
// is wrong. If we have nothing to return set end to 0.
|
||||||
if start == end {
|
if start == end || len(clientEvents) == 0 {
|
||||||
end = types.TopologyToken{}
|
end = types.TopologyToken{}
|
||||||
}
|
}
|
||||||
|
|
||||||
util.GetLogger(req.Context()).WithFields(logrus.Fields{
|
util.GetLogger(req.Context()).WithFields(logrus.Fields{
|
||||||
"from": from.String(),
|
"request_from": from.String(),
|
||||||
"to": to.String(),
|
"request_to": to.String(),
|
||||||
"limit": filter.Limit,
|
"limit": filter.Limit,
|
||||||
"backwards": backwardOrdering,
|
"backwards": backwardOrdering,
|
||||||
"return_start": start.String(),
|
"response_start": start.String(),
|
||||||
"return_end": end.String(),
|
"response_end": end.String(),
|
||||||
|
"backfilled": mReq.didBackfill,
|
||||||
}).Info("Responding")
|
}).Info("Responding")
|
||||||
|
|
||||||
res := messagesResp{
|
res := messagesResp{
|
||||||
|
@ -284,11 +286,6 @@ func OnIncomingMessagesRequest(
|
||||||
})...)
|
})...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we didn't return any events, set the end to an empty string, so it will be omitted
|
|
||||||
// in the response JSON.
|
|
||||||
if len(res.Chunk) == 0 {
|
|
||||||
res.End = ""
|
|
||||||
}
|
|
||||||
if fromStream != nil {
|
if fromStream != nil {
|
||||||
res.StartStream = fromStream.String()
|
res.StartStream = fromStream.String()
|
||||||
}
|
}
|
||||||
|
@ -328,11 +325,12 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv
|
||||||
) {
|
) {
|
||||||
emptyToken := types.TopologyToken{}
|
emptyToken := types.TopologyToken{}
|
||||||
// Retrieve the events from the local database.
|
// Retrieve the events from the local database.
|
||||||
streamEvents, err := r.snapshot.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
|
streamEvents, _, end, err := r.snapshot.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("GetEventsInRange: %w", err)
|
err = fmt.Errorf("GetEventsInRange: %w", err)
|
||||||
return []synctypes.ClientEvent{}, emptyToken, emptyToken, err
|
return []synctypes.ClientEvent{}, *r.from, emptyToken, err
|
||||||
}
|
}
|
||||||
|
end.Decrement()
|
||||||
|
|
||||||
var events []*rstypes.HeaderedEvent
|
var events []*rstypes.HeaderedEvent
|
||||||
util.GetLogger(r.ctx).WithFields(logrus.Fields{
|
util.GetLogger(r.ctx).WithFields(logrus.Fields{
|
||||||
|
@ -346,32 +344,54 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv
|
||||||
// on the ordering), or we've reached a backward extremity.
|
// on the ordering), or we've reached a backward extremity.
|
||||||
if len(streamEvents) == 0 {
|
if len(streamEvents) == 0 {
|
||||||
if events, err = r.handleEmptyEventsSlice(); err != nil {
|
if events, err = r.handleEmptyEventsSlice(); err != nil {
|
||||||
return []synctypes.ClientEvent{}, emptyToken, emptyToken, err
|
return []synctypes.ClientEvent{}, *r.from, emptyToken, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if events, err = r.handleNonEmptyEventsSlice(streamEvents); err != nil {
|
if events, err = r.handleNonEmptyEventsSlice(streamEvents); err != nil {
|
||||||
return []synctypes.ClientEvent{}, emptyToken, emptyToken, err
|
return []synctypes.ClientEvent{}, *r.from, emptyToken, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we didn't get any event, we don't need to proceed any further.
|
// If we didn't get any event, we don't need to proceed any further.
|
||||||
if len(events) == 0 {
|
if len(events) == 0 {
|
||||||
return []synctypes.ClientEvent{}, *r.from, *r.to, nil
|
return []synctypes.ClientEvent{}, *r.from, emptyToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the position of the first and the last event in the room's topology.
|
// Apply room history visibility filter
|
||||||
// This position is currently determined by the event's depth, so we could
|
startTime := time.Now()
|
||||||
// also use it instead of retrieving from the database. However, if we ever
|
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.device.UserID, "messages")
|
||||||
// change the way topological positions are defined (as depth isn't the most
|
|
||||||
// reliable way to define it), it would be easier and less troublesome to
|
|
||||||
// only have to change it in one place, i.e. the database.
|
|
||||||
start, end, err = r.getStartEnd(events)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []synctypes.ClientEvent{}, *r.from, *r.to, err
|
return []synctypes.ClientEvent{}, *r.from, *r.to, nil
|
||||||
|
}
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"duration": time.Since(startTime),
|
||||||
|
"room_id": r.roomID,
|
||||||
|
"events_before": len(events),
|
||||||
|
"events_after": len(filteredEvents),
|
||||||
|
}).Debug("applied history visibility (messages)")
|
||||||
|
|
||||||
|
// No events left after applying history visibility
|
||||||
|
if len(filteredEvents) == 0 {
|
||||||
|
return []synctypes.ClientEvent{}, *r.from, emptyToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we backfilled in the process of getting events, we need
|
||||||
|
// to re-fetch the start/end positions
|
||||||
|
if r.didBackfill {
|
||||||
|
_, end, err = r.getStartEnd(filteredEvents)
|
||||||
|
if err != nil {
|
||||||
|
return []synctypes.ClientEvent{}, *r.from, *r.to, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort the events to ensure we send them in the right order.
|
// Sort the events to ensure we send them in the right order.
|
||||||
if r.backwardOrdering {
|
if r.backwardOrdering {
|
||||||
|
if events[len(events)-1].Type() == spec.MRoomCreate {
|
||||||
|
// NOTSPEC: We've hit the beginning of the room so there's really nowhere
|
||||||
|
// else to go. This seems to fix Element iOS from looping on /messages endlessly.
|
||||||
|
end = types.TopologyToken{}
|
||||||
|
}
|
||||||
|
|
||||||
// This reverses the array from old->new to new->old
|
// This reverses the array from old->new to new->old
|
||||||
reversed := func(in []*rstypes.HeaderedEvent) []*rstypes.HeaderedEvent {
|
reversed := func(in []*rstypes.HeaderedEvent) []*rstypes.HeaderedEvent {
|
||||||
out := make([]*rstypes.HeaderedEvent, len(in))
|
out := make([]*rstypes.HeaderedEvent, len(in))
|
||||||
|
@ -380,24 +400,14 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
events = reversed(events)
|
filteredEvents = reversed(filteredEvents)
|
||||||
}
|
|
||||||
if len(events) == 0 {
|
|
||||||
return []synctypes.ClientEvent{}, *r.from, *r.to, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply room history visibility filter
|
start = *r.from
|
||||||
startTime := time.Now()
|
|
||||||
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.device.UserID, "messages")
|
|
||||||
logrus.WithFields(logrus.Fields{
|
|
||||||
"duration": time.Since(startTime),
|
|
||||||
"room_id": r.roomID,
|
|
||||||
"events_before": len(events),
|
|
||||||
"events_after": len(filteredEvents),
|
|
||||||
}).Debug("applied history visibility (messages)")
|
|
||||||
return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}), start, end, err
|
}), start, end, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *messagesReq) getStartEnd(events []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) {
|
func (r *messagesReq) getStartEnd(events []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) {
|
||||||
|
@ -450,6 +460,7 @@ func (r *messagesReq) handleEmptyEventsSlice() (
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
r.didBackfill = true
|
||||||
} else {
|
} else {
|
||||||
// If not, it means the slice was empty because we reached the room's
|
// If not, it means the slice was empty because we reached the room's
|
||||||
// creation, so return an empty slice.
|
// creation, so return an empty slice.
|
||||||
|
@ -499,7 +510,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
r.didBackfill = true
|
||||||
// Append the PDUs to the list to send back to the client.
|
// Append the PDUs to the list to send back to the client.
|
||||||
events = append(events, pdus...)
|
events = append(events, pdus...)
|
||||||
}
|
}
|
||||||
|
@ -561,15 +572,17 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
|
||||||
if res.HistoryVisibility == "" {
|
if res.HistoryVisibility == "" {
|
||||||
res.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared
|
res.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared
|
||||||
}
|
}
|
||||||
for i := range res.Events {
|
events := res.Events
|
||||||
|
for i := range events {
|
||||||
|
events[i].Visibility = res.HistoryVisibility
|
||||||
_, err = r.db.WriteEvent(
|
_, err = r.db.WriteEvent(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
res.Events[i],
|
events[i],
|
||||||
[]*rstypes.HeaderedEvent{},
|
[]*rstypes.HeaderedEvent{},
|
||||||
[]string{},
|
[]string{},
|
||||||
[]string{},
|
[]string{},
|
||||||
nil, true,
|
nil, true,
|
||||||
res.HistoryVisibility,
|
events[i].Visibility,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -577,14 +590,10 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
|
||||||
}
|
}
|
||||||
|
|
||||||
// we may have got more than the requested limit so resize now
|
// we may have got more than the requested limit so resize now
|
||||||
events := res.Events
|
|
||||||
if len(events) > limit {
|
if len(events) > limit {
|
||||||
// last `limit` events
|
// last `limit` events
|
||||||
events = events[len(events)-limit:]
|
events = events[len(events)-limit:]
|
||||||
}
|
}
|
||||||
for _, ev := range events {
|
|
||||||
ev.Visibility = res.HistoryVisibility
|
|
||||||
}
|
|
||||||
|
|
||||||
return events, nil
|
return events, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,6 +43,7 @@ func Setup(
|
||||||
cfg *config.SyncAPI,
|
cfg *config.SyncAPI,
|
||||||
lazyLoadCache caching.LazyLoadCache,
|
lazyLoadCache caching.LazyLoadCache,
|
||||||
fts fulltext.Indexer,
|
fts fulltext.Indexer,
|
||||||
|
rateLimits *httputil.RateLimits,
|
||||||
) {
|
) {
|
||||||
v1unstablemux := csMux.PathPrefix("/{apiversion:(?:v1|unstable)}/").Subrouter()
|
v1unstablemux := csMux.PathPrefix("/{apiversion:(?:v1|unstable)}/").Subrouter()
|
||||||
v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
|
v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
|
||||||
|
@ -53,6 +54,10 @@ func Setup(
|
||||||
}, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
|
}, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
// not specced, but ensure we're rate limiting requests to this endpoint
|
||||||
|
if r := rateLimits.Limit(req, device); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
|
|
|
@ -81,8 +81,11 @@ type DatabaseTransaction interface {
|
||||||
// If no data is retrieved, returns an empty map
|
// If no data is retrieved, returns an empty map
|
||||||
// If there was an issue with the retrieval, returns an error
|
// If there was an issue with the retrieval, returns an error
|
||||||
GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *synctypes.EventFilter) (map[string][]string, types.StreamPosition, error)
|
GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *synctypes.EventFilter) (map[string][]string, types.StreamPosition, error)
|
||||||
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
|
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
|
||||||
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *synctypes.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
|
// If backwardsOrdering is true, the most recent event must be first, else last.
|
||||||
|
// Returns the filtered StreamEvents on success. Returns **unfiltered** StreamEvents and ErrNoEventsForFilter if
|
||||||
|
// the provided filter removed all events, this can be used to still calculate the start/end position. (e.g for `/messages`)
|
||||||
|
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *synctypes.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, start, end types.TopologyToken, err error)
|
||||||
// EventPositionInTopology returns the depth and stream position of the given event.
|
// EventPositionInTopology returns the depth and stream position of the given event.
|
||||||
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
|
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
|
||||||
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
|
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
|
||||||
|
|
|
@ -48,14 +48,14 @@ const insertEventInTopologySQL = "" +
|
||||||
" RETURNING topological_position"
|
" RETURNING topological_position"
|
||||||
|
|
||||||
const selectEventIDsInRangeASCSQL = "" +
|
const selectEventIDsInRangeASCSQL = "" +
|
||||||
"SELECT event_id FROM syncapi_output_room_events_topology" +
|
"SELECT event_id, topological_position, stream_position FROM syncapi_output_room_events_topology" +
|
||||||
" WHERE room_id = $1 AND (" +
|
" WHERE room_id = $1 AND (" +
|
||||||
"(topological_position > $2 AND topological_position < $3) OR" +
|
"(topological_position > $2 AND topological_position < $3) OR" +
|
||||||
"(topological_position = $4 AND stream_position >= $5)" +
|
"(topological_position = $4 AND stream_position >= $5)" +
|
||||||
") ORDER BY topological_position ASC, stream_position ASC LIMIT $6"
|
") ORDER BY topological_position ASC, stream_position ASC LIMIT $6"
|
||||||
|
|
||||||
const selectEventIDsInRangeDESCSQL = "" +
|
const selectEventIDsInRangeDESCSQL = "" +
|
||||||
"SELECT event_id FROM syncapi_output_room_events_topology" +
|
"SELECT event_id, topological_position, stream_position FROM syncapi_output_room_events_topology" +
|
||||||
" WHERE room_id = $1 AND (" +
|
" WHERE room_id = $1 AND (" +
|
||||||
"(topological_position > $2 AND topological_position < $3) OR" +
|
"(topological_position > $2 AND topological_position < $3) OR" +
|
||||||
"(topological_position = $4 AND stream_position <= $5)" +
|
"(topological_position = $4 AND stream_position <= $5)" +
|
||||||
|
@ -113,12 +113,13 @@ func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectEventIDsInRange selects the IDs of events which positions are within a
|
// SelectEventIDsInRange selects the IDs of events which positions are within a
|
||||||
// given range in a given room's topological order.
|
// given range in a given room's topological order. Returns the start/end topological tokens for
|
||||||
|
// the returned eventIDs.
|
||||||
// Returns an empty slice if no events match the given range.
|
// Returns an empty slice if no events match the given range.
|
||||||
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
||||||
ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition,
|
ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition,
|
||||||
limit int, chronologicalOrder bool,
|
limit int, chronologicalOrder bool,
|
||||||
) (eventIDs []string, err error) {
|
) (eventIDs []string, start, end types.TopologyToken, err error) {
|
||||||
// Decide on the selection's order according to whether chronological order
|
// Decide on the selection's order according to whether chronological order
|
||||||
// is requested or not.
|
// is requested or not.
|
||||||
var stmt *sql.Stmt
|
var stmt *sql.Stmt
|
||||||
|
@ -132,7 +133,7 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
||||||
rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit)
|
rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
// If no event matched the request, return an empty slice.
|
// If no event matched the request, return an empty slice.
|
||||||
return []string{}, nil
|
return []string{}, start, end, nil
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -140,14 +141,23 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
||||||
|
|
||||||
// Return the IDs.
|
// Return the IDs.
|
||||||
var eventID string
|
var eventID string
|
||||||
|
var token types.TopologyToken
|
||||||
|
var tokens []types.TopologyToken
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
if err = rows.Scan(&eventID); err != nil {
|
if err = rows.Scan(&eventID, &token.Depth, &token.PDUPosition); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
eventIDs = append(eventIDs, eventID)
|
eventIDs = append(eventIDs, eventID)
|
||||||
|
tokens = append(tokens, token)
|
||||||
}
|
}
|
||||||
|
|
||||||
return eventIDs, rows.Err()
|
// The values are already ordered by SQL, so we can use them as is.
|
||||||
|
if len(tokens) > 0 {
|
||||||
|
start = tokens[0]
|
||||||
|
end = tokens[len(tokens)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
return eventIDs, start, end, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectPositionInTopology returns the position of a given event in the
|
// SelectPositionInTopology returns the position of a given event in the
|
||||||
|
|
|
@ -237,7 +237,7 @@ func (d *DatabaseTransaction) GetEventsInTopologicalRange(
|
||||||
roomID string,
|
roomID string,
|
||||||
filter *synctypes.RoomEventFilter,
|
filter *synctypes.RoomEventFilter,
|
||||||
backwardOrdering bool,
|
backwardOrdering bool,
|
||||||
) (events []types.StreamEvent, err error) {
|
) (events []types.StreamEvent, start, end types.TopologyToken, err error) {
|
||||||
var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
|
var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
|
||||||
if backwardOrdering {
|
if backwardOrdering {
|
||||||
// Backward ordering means the 'from' token has a higher depth than the 'to' token
|
// Backward ordering means the 'from' token has a higher depth than the 'to' token
|
||||||
|
@ -255,7 +255,7 @@ func (d *DatabaseTransaction) GetEventsInTopologicalRange(
|
||||||
|
|
||||||
// Select the event IDs from the defined range.
|
// Select the event IDs from the defined range.
|
||||||
var eIDs []string
|
var eIDs []string
|
||||||
eIDs, err = d.Topology.SelectEventIDsInRange(
|
eIDs, start, end, err = d.Topology.SelectEventIDsInRange(
|
||||||
ctx, d.txn, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
|
ctx, d.txn, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -264,6 +264,10 @@ func (d *DatabaseTransaction) GetEventsInTopologicalRange(
|
||||||
|
|
||||||
// Retrieve the events' contents using their IDs.
|
// Retrieve the events' contents using their IDs.
|
||||||
events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eIDs, filter, true)
|
events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eIDs, filter, true)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -44,14 +44,14 @@ const insertEventInTopologySQL = "" +
|
||||||
" ON CONFLICT DO NOTHING"
|
" ON CONFLICT DO NOTHING"
|
||||||
|
|
||||||
const selectEventIDsInRangeASCSQL = "" +
|
const selectEventIDsInRangeASCSQL = "" +
|
||||||
"SELECT event_id FROM syncapi_output_room_events_topology" +
|
"SELECT event_id, topological_position, stream_position FROM syncapi_output_room_events_topology" +
|
||||||
" WHERE room_id = $1 AND (" +
|
" WHERE room_id = $1 AND (" +
|
||||||
"(topological_position > $2 AND topological_position < $3) OR" +
|
"(topological_position > $2 AND topological_position < $3) OR" +
|
||||||
"(topological_position = $4 AND stream_position >= $5)" +
|
"(topological_position = $4 AND stream_position >= $5)" +
|
||||||
") ORDER BY topological_position ASC, stream_position ASC LIMIT $6"
|
") ORDER BY topological_position ASC, stream_position ASC LIMIT $6"
|
||||||
|
|
||||||
const selectEventIDsInRangeDESCSQL = "" +
|
const selectEventIDsInRangeDESCSQL = "" +
|
||||||
"SELECT event_id FROM syncapi_output_room_events_topology" +
|
"SELECT event_id, topological_position, stream_position FROM syncapi_output_room_events_topology" +
|
||||||
" WHERE room_id = $1 AND (" +
|
" WHERE room_id = $1 AND (" +
|
||||||
"(topological_position > $2 AND topological_position < $3) OR" +
|
"(topological_position > $2 AND topological_position < $3) OR" +
|
||||||
"(topological_position = $4 AND stream_position <= $5)" +
|
"(topological_position = $4 AND stream_position <= $5)" +
|
||||||
|
@ -111,11 +111,15 @@ func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
|
||||||
return types.StreamPosition(event.Depth()), err
|
return types.StreamPosition(event.Depth()), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SelectEventIDsInRange selects the IDs of events which positions are within a
|
||||||
|
// given range in a given room's topological order. Returns the start/end topological tokens for
|
||||||
|
// the returned eventIDs.
|
||||||
|
// Returns an empty slice if no events match the given range.
|
||||||
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
||||||
ctx context.Context, txn *sql.Tx, roomID string,
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
minDepth, maxDepth, maxStreamPos types.StreamPosition,
|
minDepth, maxDepth, maxStreamPos types.StreamPosition,
|
||||||
limit int, chronologicalOrder bool,
|
limit int, chronologicalOrder bool,
|
||||||
) (eventIDs []string, err error) {
|
) (eventIDs []string, start, end types.TopologyToken, err error) {
|
||||||
// Decide on the selection's order according to whether chronological order
|
// Decide on the selection's order according to whether chronological order
|
||||||
// is requested or not.
|
// is requested or not.
|
||||||
var stmt *sql.Stmt
|
var stmt *sql.Stmt
|
||||||
|
@ -129,18 +133,27 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
||||||
rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit)
|
rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
// If no event matched the request, return an empty slice.
|
// If no event matched the request, return an empty slice.
|
||||||
return []string{}, nil
|
return []string{}, start, end, nil
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the IDs.
|
// Return the IDs.
|
||||||
var eventID string
|
var eventID string
|
||||||
|
var token types.TopologyToken
|
||||||
|
var tokens []types.TopologyToken
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
if err = rows.Scan(&eventID); err != nil {
|
if err = rows.Scan(&eventID, &token.Depth, &token.PDUPosition); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
eventIDs = append(eventIDs, eventID)
|
eventIDs = append(eventIDs, eventID)
|
||||||
|
tokens = append(tokens, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The values are already ordered by SQL, so we can use them as is.
|
||||||
|
if len(tokens) > 0 {
|
||||||
|
start = tokens[0]
|
||||||
|
end = tokens[len(tokens)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
|
@ -213,12 +213,48 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
|
||||||
|
|
||||||
// backpaginate 5 messages starting at the latest position.
|
// backpaginate 5 messages starting at the latest position.
|
||||||
filter := &synctypes.RoomEventFilter{Limit: 5}
|
filter := &synctypes.RoomEventFilter{Limit: 5}
|
||||||
paginatedEvents, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
|
paginatedEvents, start, end, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
|
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
|
||||||
}
|
}
|
||||||
gots := snapshot.StreamEventsToEvents(context.Background(), nil, paginatedEvents, nil)
|
gots := snapshot.StreamEventsToEvents(context.Background(), nil, paginatedEvents, nil)
|
||||||
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
|
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
|
||||||
|
assert.Equal(t, types.TopologyToken{Depth: 15, PDUPosition: 15}, start)
|
||||||
|
assert.Equal(t, types.TopologyToken{Depth: 11, PDUPosition: 11}, end)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// The purpose of this test is to ensure that backfilling returns no start/end if a given filter removes
|
||||||
|
// all events.
|
||||||
|
func TestGetEventsInRangeWithTopologyTokenNoEventsForFilter(t *testing.T) {
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := MustCreateDatabase(t, dbType)
|
||||||
|
defer close()
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
r := test.NewRoom(t, alice)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})
|
||||||
|
}
|
||||||
|
events := r.Events()
|
||||||
|
_ = MustWriteEvents(t, db, events)
|
||||||
|
|
||||||
|
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
|
||||||
|
from := types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}
|
||||||
|
t.Logf("max topo pos = %+v", from)
|
||||||
|
// head towards the beginning of time
|
||||||
|
to := types.TopologyToken{}
|
||||||
|
|
||||||
|
// backpaginate 20 messages starting at the latest position.
|
||||||
|
notTypes := []string{spec.MRoomRedaction}
|
||||||
|
senders := []string{alice.ID}
|
||||||
|
filter := &synctypes.RoomEventFilter{Limit: 20, NotTypes: ¬Types, Senders: &senders}
|
||||||
|
paginatedEvents, start, end, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, len(paginatedEvents))
|
||||||
|
// Even if we didn't get anything back due to the filter, we should still have start/end
|
||||||
|
assert.Equal(t, types.TopologyToken{Depth: 15, PDUPosition: 15}, start)
|
||||||
|
assert.Equal(t, types.TopologyToken{Depth: 1, PDUPosition: 1}, end)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -89,11 +89,11 @@ type Topology interface {
|
||||||
// InsertEventInTopology inserts the given event in the room's topology, based on the event's depth.
|
// InsertEventInTopology inserts the given event in the room's topology, based on the event's depth.
|
||||||
// `pos` is the stream position of this event in the events table, and is used to order events which have the same depth.
|
// `pos` is the stream position of this event in the events table, and is used to order events which have the same depth.
|
||||||
InsertEventInTopology(ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent, pos types.StreamPosition) (topoPos types.StreamPosition, err error)
|
InsertEventInTopology(ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent, pos types.StreamPosition) (topoPos types.StreamPosition, err error)
|
||||||
// SelectEventIDsInRange selects the IDs of events whose depths are within a given range in a given room's topological order.
|
// SelectEventIDsInRange selects the IDs and the topological position of events whose depths are within a given range in a given room's topological order.
|
||||||
// Events with `minDepth` are *exclusive*, as is the event which has exactly `minDepth`,`maxStreamPos`.
|
// Events with `minDepth` are *exclusive*, as is the event which has exactly `minDepth`,`maxStreamPos`. Returns the eventIDs and start/end topological tokens.
|
||||||
// `maxStreamPos` is only used when events have the same depth as `maxDepth`, which results in events less than `maxStreamPos` being returned.
|
// `maxStreamPos` is only used when events have the same depth as `maxDepth`, which results in events less than `maxStreamPos` being returned.
|
||||||
// Returns an empty slice if no events match the given range.
|
// Returns an empty slice if no events match the given range.
|
||||||
SelectEventIDsInRange(ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, err error)
|
SelectEventIDsInRange(ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, start, end types.TopologyToken, err error)
|
||||||
// SelectPositionInTopology returns the depth and stream position of a given event in the topology of the room it belongs to.
|
// SelectPositionInTopology returns the depth and stream position of a given event in the topology of the room it belongs to.
|
||||||
SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error)
|
SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error)
|
||||||
// SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room.
|
// SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room.
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) {
|
func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) {
|
||||||
|
@ -60,28 +61,37 @@ func TestTopologyTable(t *testing.T) {
|
||||||
highestPos = topoPos + 1
|
highestPos = topoPos + 1
|
||||||
}
|
}
|
||||||
// check ordering works without limit
|
// check ordering works without limit
|
||||||
eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true)
|
eventIDs, start, end, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true)
|
||||||
if err != nil {
|
assert.NoError(t, err, "failed to SelectEventIDsInRange")
|
||||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
|
||||||
}
|
|
||||||
test.AssertEventIDsEqual(t, eventIDs, events[:])
|
test.AssertEventIDsEqual(t, eventIDs, events[:])
|
||||||
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false)
|
assert.Equal(t, types.TopologyToken{Depth: 1, PDUPosition: 0}, start)
|
||||||
if err != nil {
|
assert.Equal(t, types.TopologyToken{Depth: 5, PDUPosition: 4}, end)
|
||||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
|
||||||
}
|
|
||||||
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:]))
|
|
||||||
// check ordering works with limit
|
|
||||||
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
|
||||||
}
|
|
||||||
test.AssertEventIDsEqual(t, eventIDs, events[:3])
|
|
||||||
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
|
||||||
}
|
|
||||||
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:]))
|
|
||||||
|
|
||||||
|
eventIDs, start, end, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false)
|
||||||
|
assert.NoError(t, err, "failed to SelectEventIDsInRange")
|
||||||
|
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:]))
|
||||||
|
assert.Equal(t, types.TopologyToken{Depth: 5, PDUPosition: 4}, start)
|
||||||
|
assert.Equal(t, types.TopologyToken{Depth: 1, PDUPosition: 0}, end)
|
||||||
|
|
||||||
|
// check ordering works with limit
|
||||||
|
eventIDs, start, end, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true)
|
||||||
|
assert.NoError(t, err, "failed to SelectEventIDsInRange")
|
||||||
|
test.AssertEventIDsEqual(t, eventIDs, events[:3])
|
||||||
|
assert.Equal(t, types.TopologyToken{Depth: 1, PDUPosition: 0}, start)
|
||||||
|
assert.Equal(t, types.TopologyToken{Depth: 3, PDUPosition: 2}, end)
|
||||||
|
|
||||||
|
eventIDs, start, end, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false)
|
||||||
|
assert.NoError(t, err, "failed to SelectEventIDsInRange")
|
||||||
|
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:]))
|
||||||
|
assert.Equal(t, types.TopologyToken{Depth: 5, PDUPosition: 4}, start)
|
||||||
|
assert.Equal(t, types.TopologyToken{Depth: 3, PDUPosition: 2}, end)
|
||||||
|
|
||||||
|
// Check that we return no values for invalid rooms
|
||||||
|
eventIDs, start, end, err = tab.SelectEventIDsInRange(ctx, txn, "!doesnotexist:localhost", 0, highestPos, highestPos, 10, false)
|
||||||
|
assert.NoError(t, err, "failed to SelectEventIDsInRange")
|
||||||
|
assert.Equal(t, 0, len(eventIDs))
|
||||||
|
assert.Equal(t, types.TopologyToken{}, start)
|
||||||
|
assert.Equal(t, types.TopologyToken{}, end)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -144,8 +144,11 @@ func AddPublicRoutes(
|
||||||
logrus.WithError(err).Panicf("failed to start receipts consumer")
|
logrus.WithError(err).Panicf("failed to start receipts consumer")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rateLimits := httputil.NewRateLimits(&dendriteCfg.ClientAPI.RateLimiting)
|
||||||
|
|
||||||
routing.Setup(
|
routing.Setup(
|
||||||
routers.Client, requestPool, syncDB, userAPI,
|
routers.Client, requestPool, syncDB, userAPI,
|
||||||
rsAPI, &dendriteCfg.SyncAPI, caches, fts,
|
rsAPI, &dendriteCfg.SyncAPI, caches, fts,
|
||||||
|
rateLimits,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -433,6 +433,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||||
|
cfg.ClientAPI.RateLimiting = config.RateLimiting{Enabled: false}
|
||||||
routers := httputil.NewRouters()
|
routers := httputil.NewRouters()
|
||||||
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
|
|
|
@ -405,18 +405,25 @@ func newLocalMembership(event *synctypes.ClientEvent) (*localMembership, error)
|
||||||
// localRoomMembers fetches the current local members of a room, and
|
// localRoomMembers fetches the current local members of a room, and
|
||||||
// the total number of members.
|
// the total number of members.
|
||||||
func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) {
|
func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) {
|
||||||
|
// Get only locally joined users to avoid unmarshalling and caching
|
||||||
|
// membership events we only use to calculate the room size.
|
||||||
req := &rsapi.QueryMembershipsForRoomRequest{
|
req := &rsapi.QueryMembershipsForRoomRequest{
|
||||||
RoomID: roomID,
|
RoomID: roomID,
|
||||||
JoinedOnly: true,
|
JoinedOnly: true,
|
||||||
|
LocalOnly: true,
|
||||||
}
|
}
|
||||||
var res rsapi.QueryMembershipsForRoomResponse
|
var res rsapi.QueryMembershipsForRoomResponse
|
||||||
|
|
||||||
// XXX: This could potentially race if the state for the event is not known yet
|
|
||||||
// e.g. the event came over federation but we do not have the full state persisted.
|
|
||||||
if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil {
|
if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Since we only queried locally joined users above,
|
||||||
|
// we also need to ask the roomserver about the joined user count.
|
||||||
|
totalCount, err := s.rsAPI.JoinedUserCount(ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
var members []*localMembership
|
var members []*localMembership
|
||||||
for _, event := range res.JoinEvents {
|
for _, event := range res.JoinEvents {
|
||||||
// Filter out invalid join events
|
// Filter out invalid join events
|
||||||
|
@ -426,31 +433,18 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s
|
||||||
if *event.StateKey == "" {
|
if *event.StateKey == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
_, serverName, err := gomatrixserverlib.SplitID('@', *event.StateKey)
|
// We're going to trust the Query from above to really just return
|
||||||
if err != nil {
|
// local users
|
||||||
log.WithError(err).Error("failed to get servername from statekey")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Only get memberships for our server
|
|
||||||
if serverName != s.serverName {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
member, err := newLocalMembership(&event)
|
member, err := newLocalMembership(&event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Errorf("Parsing MemberContent")
|
log.WithError(err).Errorf("Parsing MemberContent")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if member.Membership != spec.Join {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if member.Domain != s.cfg.Matrix.ServerName {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
members = append(members, member)
|
members = append(members, member)
|
||||||
}
|
}
|
||||||
|
|
||||||
return members, len(res.JoinEvents), nil
|
return members, totalCount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// roomName returns the name in the event (if type==m.room.name), or
|
// roomName returns the name in the event (if type==m.room.name), or
|
||||||
|
|
|
@ -2,16 +2,22 @@ package consumers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||||
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
@ -139,6 +145,42 @@ func Test_evaluatePushRules(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLocalRoomMembers(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
_, sk, err := ed25519.GenerateKey(nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
bob := test.NewUser(t, test.WithSigningServer("notlocalhost", "ed25519:abc", sk))
|
||||||
|
charlie := test.NewUser(t, test.WithSigningServer("notlocalhost", "ed25519:abc", sk))
|
||||||
|
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]string{"membership": spec.Join}, test.WithStateKey(bob.ID))
|
||||||
|
room.CreateAndInsert(t, charlie, spec.MRoomMember, map[string]string{"membership": spec.Join}, test.WithStateKey(charlie.ID))
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
|
natsInstance := &jetstream.NATSInstance{}
|
||||||
|
caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics)
|
||||||
|
rsAPI.SetFederationAPI(nil, nil)
|
||||||
|
db, err := storage.NewUserDatabase(processCtx.Context(), cm, &cfg.UserAPI.AccountDatabase, cfg.Global.ServerName, bcrypt.MinCost, 1000, 1000, "")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = rsapi.SendEvents(processCtx.Context(), rsAPI, rsapi.KindNew, room.Events(), "", "test", "test", nil, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
consumer := OutputRoomEventConsumer{db: db, rsAPI: rsAPI, serverName: "test", cfg: &cfg.UserAPI}
|
||||||
|
members, count, err := consumer.localRoomMembers(processCtx.Context(), room.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 3, count)
|
||||||
|
expectedLocalMember := &localMembership{UserID: alice.ID, Localpart: alice.Localpart, Domain: "test", MemberContent: gomatrixserverlib.MemberContent{Membership: spec.Join}}
|
||||||
|
assert.Equal(t, expectedLocalMember, members[0])
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func TestMessageStats(t *testing.T) {
|
func TestMessageStats(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
eventType string
|
eventType string
|
||||||
|
@ -257,3 +299,42 @@ func TestMessageStats(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkLocalRoomMembers(b *testing.B) {
|
||||||
|
t := &testing.T{}
|
||||||
|
|
||||||
|
cfg, processCtx, close := testrig.CreateConfig(t, test.DBTypePostgres)
|
||||||
|
defer close()
|
||||||
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
|
natsInstance := &jetstream.NATSInstance{}
|
||||||
|
caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics)
|
||||||
|
rsAPI.SetFederationAPI(nil, nil)
|
||||||
|
db, err := storage.NewUserDatabase(processCtx.Context(), cm, &cfg.UserAPI.AccountDatabase, cfg.Global.ServerName, bcrypt.MinCost, 1000, 1000, "")
|
||||||
|
assert.NoError(b, err)
|
||||||
|
|
||||||
|
consumer := OutputRoomEventConsumer{db: db, rsAPI: rsAPI, serverName: "test", cfg: &cfg.UserAPI}
|
||||||
|
_, sk, err := ed25519.GenerateKey(nil)
|
||||||
|
assert.NoError(b, err)
|
||||||
|
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
user := test.NewUser(t, test.WithSigningServer("notlocalhost", "ed25519:abc", sk))
|
||||||
|
room.CreateAndInsert(t, user, spec.MRoomMember, map[string]string{"membership": spec.Join}, test.WithStateKey(user.ID))
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rsapi.SendEvents(processCtx.Context(), rsAPI, rsapi.KindNew, room.Events(), "", "test", "test", nil, false)
|
||||||
|
assert.NoError(b, err)
|
||||||
|
|
||||||
|
expectedLocalMember := &localMembership{UserID: alice.ID, Localpart: alice.Localpart, Domain: "test", MemberContent: gomatrixserverlib.MemberContent{Membership: spec.Join}}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
members, count, err := consumer.localRoomMembers(processCtx.Context(), room.ID)
|
||||||
|
assert.NoError(b, err)
|
||||||
|
assert.Equal(b, 101, count)
|
||||||
|
assert.Equal(b, expectedLocalMember, members[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue