Merge branch 'matrix-org-main'

This commit is contained in:
Brian Meek 2023-02-07 17:30:09 -08:00
commit c74798b5ab
No known key found for this signature in database
GPG key ID: ACBD71263BF42D00
28 changed files with 843 additions and 221 deletions

View file

@ -273,7 +273,11 @@ jobs:
- name: Setup go - name: Setup go
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
<<<<<<< HEAD
go-version: "1.19" go-version: "1.19"
=======
go-version: "1.18"
>>>>>>> eb29a315507f0075c2c6a495ac59c64a7f45f9fc
cache: true cache: true
- name: Build upgrade-tests - name: Build upgrade-tests
run: go build ./cmd/dendrite-upgrade-tests run: go build ./cmd/dendrite-upgrade-tests
@ -293,7 +297,11 @@ jobs:
- name: Setup go - name: Setup go
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
<<<<<<< HEAD
go-version: "1.19" go-version: "1.19"
=======
go-version: "1.18"
>>>>>>> eb29a315507f0075c2c6a495ac59c64a7f45f9fc
cache: true cache: true
- name: Build upgrade-tests - name: Build upgrade-tests
run: go build ./cmd/dendrite-upgrade-tests run: go build ./cmd/dendrite-upgrade-tests

View file

@ -433,7 +433,7 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
return nil return nil
} }
// QueryMembershipAtEventRequest requests the membership events for a user // QueryMembershipAtEventRequest requests the membership event for a user
// for a list of eventIDs. // for a list of eventIDs.
type QueryMembershipAtEventRequest struct { type QueryMembershipAtEventRequest struct {
RoomID string RoomID string
@ -443,9 +443,10 @@ type QueryMembershipAtEventRequest struct {
// QueryMembershipAtEventResponse is the response to QueryMembershipAtEventRequest. // QueryMembershipAtEventResponse is the response to QueryMembershipAtEventRequest.
type QueryMembershipAtEventResponse struct { type QueryMembershipAtEventResponse struct {
// Memberships is a map from eventID to a list of events (if any). Events that // Membership is a map from eventID to membership event. Events that
// do not have known state will return an empty array here. // do not have known state will return a nil event, resulting in a "leave" membership
Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"` // when calculating history visibility.
Membership map[string]*gomatrixserverlib.HeaderedEvent `json:"membership"`
} }
// QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a // QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a

View file

@ -21,6 +21,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -216,7 +217,8 @@ func (r *Queryer) QueryMembershipAtEvent(
request *api.QueryMembershipAtEventRequest, request *api.QueryMembershipAtEventRequest,
response *api.QueryMembershipAtEventResponse, response *api.QueryMembershipAtEventResponse,
) error { ) error {
response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent) response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent)
info, err := r.DB.RoomInfo(ctx, request.RoomID) info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil { if err != nil {
return fmt.Errorf("unable to get roomInfo: %w", err) return fmt.Errorf("unable to get roomInfo: %w", err)
@ -234,7 +236,17 @@ func (r *Queryer) QueryMembershipAtEvent(
return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID) return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID)
} }
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, request.EventIDs, stateKeyNIDs[request.UserID]) response.Membership, err = r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[request.UserID], info, request.EventIDs...)
switch err {
case nil:
return nil
case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event
default:
return err
}
response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent)
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID])
if err != nil { if err != nil {
return fmt.Errorf("unable to get state before event: %w", err) return fmt.Errorf("unable to get state before event: %w", err)
} }
@ -258,7 +270,7 @@ func (r *Queryer) QueryMembershipAtEvent(
for _, eventID := range request.EventIDs { for _, eventID := range request.EventIDs {
stateEntry, ok := stateEntries[eventID] stateEntry, ok := stateEntries[eventID]
if !ok || len(stateEntry) == 0 { if !ok || len(stateEntry) == 0 {
response.Memberships[eventID] = []*gomatrixserverlib.HeaderedEvent{} response.Membership[eventID] = nil
continue continue
} }
@ -275,15 +287,15 @@ func (r *Queryer) QueryMembershipAtEvent(
return fmt.Errorf("unable to get memberships at state: %w", err) return fmt.Errorf("unable to get memberships at state: %w", err)
} }
res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships)) // Iterate over all membership events we got. Given we only query the membership for
// one user and assuming this user only ever has one membership event associated to
// a given event, overwrite any other existing membership events.
for i := range memberships { for i := range memberships {
ev := memberships[i] ev := memberships[i]
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) { if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) {
res = append(res, ev.Headered(info.RoomVersion)) response.Membership[eventID] = ev.Event.Headered(info.RoomVersion)
} }
} }
response.Memberships[eventID] = res
} }
return nil return nil

View file

@ -175,4 +175,11 @@ type Database interface {
GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error)
PurgeRoom(ctx context.Context, roomID string) error PurgeRoom(ctx context.Context, roomID string) error
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
// GetMembershipForHistoryVisibility queries the membership events for the given eventIDs.
// Returns a map from (input) eventID -> membership event. If no membership event is found, returns an empty event, resulting in
// a membership of "leave" when calculating history visibility.
GetMembershipForHistoryVisibility(
ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error)
} }

View file

@ -69,6 +69,9 @@ CREATE TABLE IF NOT EXISTS roomserver_events (
auth_event_nids BIGINT[] NOT NULL, auth_event_nids BIGINT[] NOT NULL,
is_rejected BOOLEAN NOT NULL DEFAULT FALSE is_rejected BOOLEAN NOT NULL DEFAULT FALSE
); );
-- Create an index which helps in resolving membership events (event_type_nid = 5) - (used for history visibility)
CREATE INDEX IF NOT EXISTS roomserver_events_memberships_idx ON roomserver_events (room_nid, event_state_key_nid) WHERE (event_type_nid = 5);
` `
const insertEventSQL = "" + const insertEventSQL = "" +

View file

@ -21,10 +21,10 @@ import (
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -99,10 +99,26 @@ const bulkSelectStateForHistoryVisibilitySQL = `
AND (event_type_nid = 7 OR event_state_key LIKE '%:' || $2); AND (event_type_nid = 7 OR event_state_key LIKE '%:' || $2);
` `
// bulkSelectMembershipForHistoryVisibilitySQL is an optimization to get membership events for a specific user for defined set of events.
// Returns the event_id of the event we want the membership event for, the event_id of the membership event and the membership event JSON.
const bulkSelectMembershipForHistoryVisibilitySQL = `
SELECT re.event_id, re2.event_id, rej.event_json
FROM roomserver_events re
LEFT JOIN roomserver_state_snapshots rss on re.state_snapshot_nid = rss.state_snapshot_nid
CROSS JOIN unnest(rss.state_block_nids) AS blocks(block_nid)
LEFT JOIN roomserver_state_block rsb ON rsb.state_block_nid = blocks.block_nid
CROSS JOIN unnest(rsb.event_nids) AS rsb2(event_nid)
JOIN roomserver_events re2 ON re2.room_nid = $3 AND re2.event_type_nid = 5 AND re2.event_nid = rsb2.event_nid AND re2.event_state_key_nid = $1
LEFT JOIN roomserver_event_json rej ON rej.event_nid = re2.event_nid
WHERE re.event_id = ANY($2)
`
type stateSnapshotStatements struct { type stateSnapshotStatements struct {
insertStateStmt *sql.Stmt insertStateStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt
bulkSelectStateForHistoryVisibilityStmt *sql.Stmt bulkSelectStateForHistoryVisibilityStmt *sql.Stmt
bulktSelectMembershipForHistoryVisibilityStmt *sql.Stmt
} }
func CreateStateSnapshotTable(db *sql.DB) error { func CreateStateSnapshotTable(db *sql.DB) error {
@ -110,13 +126,14 @@ func CreateStateSnapshotTable(db *sql.DB) error {
return err return err
} }
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { func PrepareStateSnapshotTable(db *sql.DB) (*stateSnapshotStatements, error) {
s := &stateSnapshotStatements{} s := &stateSnapshotStatements{}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertStateStmt, insertStateSQL}, {&s.insertStateStmt, insertStateSQL},
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
{&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL}, {&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL},
{&s.bulktSelectMembershipForHistoryVisibilityStmt, bulkSelectMembershipForHistoryVisibilitySQL},
}.Prepare(db) }.Prepare(db)
} }
@ -185,3 +202,45 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility(
} }
return results, rows.Err() return results, rows.Err()
} }
func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility(
ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.bulktSelectMembershipForHistoryVisibilityStmt)
rows, err := stmt.QueryContext(ctx, userNID, pq.Array(eventIDs), roomInfo.RoomNID)
if err != nil {
return nil, err
}
defer rows.Close() // nolint: errcheck
result := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs))
var evJson []byte
var eventID string
var membershipEventID string
knownEvents := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs))
for rows.Next() {
if err = rows.Scan(&eventID, &membershipEventID, &evJson); err != nil {
return nil, err
}
if len(evJson) == 0 {
result[eventID] = &gomatrixserverlib.HeaderedEvent{}
continue
}
// If we already know this event, don't try to marshal the json again
if ev, ok := knownEvents[membershipEventID]; ok {
result[eventID] = ev
continue
}
event, err := gomatrixserverlib.NewEventFromTrustedJSON(evJson, false, roomInfo.RoomVersion)
if err != nil {
result[eventID] = &gomatrixserverlib.HeaderedEvent{}
// not fatal
continue
}
he := event.Headered(roomInfo.RoomVersion)
result[eventID] = he
knownEvents[membershipEventID] = he
}
return result, rows.Err()
}

View file

@ -51,6 +51,12 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
return true return true
} }
func (d *Database) GetMembershipForHistoryVisibility(
ctx context.Context, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error) {
return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...)
}
func (d *Database) EventTypeNIDs( func (d *Database) EventTypeNIDs(
ctx context.Context, eventTypes []string, ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) { ) (map[string]types.EventTypeNID, error) {

View file

@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -152,6 +153,10 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility(
return nil, tables.OptimisationNotSupportedError return nil, tables.OptimisationNotSupportedError
} }
func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string) (map[string]*gomatrixserverlib.HeaderedEvent, error) {
return nil, tables.OptimisationNotSupportedError
}
func (s *stateSnapshotStatements) selectStateBlockNIDsForRoomNID( func (s *stateSnapshotStatements) selectStateBlockNIDsForRoomNID(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.StateBlockNID, error) { ) ([]types.StateBlockNID, error) {

View file

@ -91,6 +91,10 @@ type StateSnapshot interface {
// which users are in a room faster than having to load the entire room state. In the // which users are in a room faster than having to load the entire room state. In the
// case of SQLite, this will return tables.OptimisationNotSupportedError. // case of SQLite, this will return tables.OptimisationNotSupportedError.
BulkSelectStateForHistoryVisibility(ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string) ([]types.EventNID, error) BulkSelectStateForHistoryVisibility(ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string) ([]types.EventNID, error)
BulkSelectMembershipForHistoryVisibility(
ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error)
} }
type StateBlock interface { type StateBlock interface {

View file

@ -29,6 +29,8 @@ func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables.
assert.NoError(t, err) assert.NoError(t, err)
err = postgres.CreateEventsTable(db) err = postgres.CreateEventsTable(db)
assert.NoError(t, err) assert.NoError(t, err)
err = postgres.CreateEventJSONTable(db)
assert.NoError(t, err)
err = postgres.CreateStateBlockTable(db) err = postgres.CreateStateBlockTable(db)
assert.NoError(t, err) assert.NoError(t, err)
// ... and then the snapshot table itself // ... and then the snapshot table itself

View file

@ -121,10 +121,7 @@ func ApplyHistoryVisibilityFilter(
// Get the mapping from eventID -> eventVisibility // Get the mapping from eventID -> eventVisibility
eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events)) eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events))
visibilities, err := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID()) visibilities := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID())
if err != nil {
return eventsFiltered, err
}
for _, ev := range events { for _, ev := range events {
evVis := visibilities[ev.EventID()] evVis := visibilities[ev.EventID()]
evVis.membershipCurrent = membershipCurrent evVis.membershipCurrent = membershipCurrent
@ -175,7 +172,7 @@ func visibilityForEvents(
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
events []*gomatrixserverlib.HeaderedEvent, events []*gomatrixserverlib.HeaderedEvent,
userID, roomID string, userID, roomID string,
) (map[string]eventVisibility, error) { ) map[string]eventVisibility {
eventIDs := make([]string, len(events)) eventIDs := make([]string, len(events))
for i := range events { for i := range events {
eventIDs[i] = events[i].EventID() eventIDs[i] = events[i].EventID()
@ -185,6 +182,7 @@ func visibilityForEvents(
// get the membership events for all eventIDs // get the membership events for all eventIDs
membershipResp := &api.QueryMembershipAtEventResponse{} membershipResp := &api.QueryMembershipAtEventResponse{}
err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{ err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{
RoomID: roomID, RoomID: roomID,
EventIDs: eventIDs, EventIDs: eventIDs,
@ -201,19 +199,20 @@ func visibilityForEvents(
membershipAtEvent: gomatrixserverlib.Leave, // default to leave, to not expose events by accident membershipAtEvent: gomatrixserverlib.Leave, // default to leave, to not expose events by accident
visibility: event.Visibility, visibility: event.Visibility,
} }
membershipEvs, ok := membershipResp.Memberships[eventID] ev, ok := membershipResp.Membership[eventID]
if !ok { if !ok || ev == nil {
result[eventID] = vis result[eventID] = vis
continue continue
} }
for _, ev := range membershipEvs {
membership, err := ev.Membership() membership, err := ev.Membership()
if err != nil { if err != nil {
return result, err result[eventID] = vis
continue
} }
vis.membershipAtEvent = membership vis.membershipAtEvent = membership
}
result[eventID] = vis result[eventID] = vis
} }
return result, nil return result
} }

View file

@ -67,6 +67,8 @@ func Context(
errMsg = "unable to parse filter" errMsg = "unable to parse filter"
case *strconv.NumError: case *strconv.NumError:
errMsg = "unable to parse limit" errMsg = "unable to parse limit"
default:
errMsg = err.Error()
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -167,7 +169,18 @@ func Context(
eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBeforeFiltered, gomatrixserverlib.FormatAll) eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBeforeFiltered, gomatrixserverlib.FormatAll)
eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll) eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll)
newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache)
newState := state
if filter.LazyLoadMembers {
allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...)
allEvents = append(allEvents, &requestedEvent)
evs := gomatrixserverlib.HeaderedToClientEvents(allEvents, gomatrixserverlib.FormatAll)
newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache)
if err != nil {
logrus.WithError(err).Error("unable to load membership events")
return jsonerror.InternalServerError()
}
}
ev := gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll) ev := gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll)
response := ContextRespsonse{ response := ContextRespsonse{
@ -244,41 +257,43 @@ func getStartEnd(ctx context.Context, snapshot storage.DatabaseTransaction, star
} }
func applyLazyLoadMembers( func applyLazyLoadMembers(
ctx context.Context,
device *userapi.Device, device *userapi.Device,
filter *gomatrixserverlib.RoomEventFilter, snapshot storage.DatabaseTransaction,
eventsAfter, eventsBefore []gomatrixserverlib.ClientEvent, roomID string,
state []*gomatrixserverlib.HeaderedEvent, events []gomatrixserverlib.ClientEvent,
lazyLoadCache caching.LazyLoadCache, lazyLoadCache caching.LazyLoadCache,
) []*gomatrixserverlib.HeaderedEvent { ) ([]*gomatrixserverlib.HeaderedEvent, error) {
if filter == nil || !filter.LazyLoadMembers { eventSenders := make(map[string]struct{})
return state
}
allEvents := append(eventsBefore, eventsAfter...)
x := make(map[string]struct{})
// get members who actually send an event // get members who actually send an event
for _, e := range allEvents { for _, e := range events {
// Don't add membership events the client should already know about // Don't add membership events the client should already know about
if _, cached := lazyLoadCache.IsLazyLoadedUserCached(device, e.RoomID, e.Sender); cached { if _, cached := lazyLoadCache.IsLazyLoadedUserCached(device, e.RoomID, e.Sender); cached {
continue continue
} }
x[e.Sender] = struct{}{} eventSenders[e.Sender] = struct{}{}
} }
newState := []*gomatrixserverlib.HeaderedEvent{} wantUsers := make([]string, 0, len(eventSenders))
membershipEvents := []*gomatrixserverlib.HeaderedEvent{} for userID := range eventSenders {
for _, event := range state { wantUsers = append(wantUsers, userID)
if event.Type() != gomatrixserverlib.MRoomMember {
newState = append(newState, event)
} else {
// did the user send an event?
if _, ok := x[event.Sender()]; ok {
membershipEvents = append(membershipEvents, event)
lazyLoadCache.StoreLazyLoadedUser(device, event.RoomID(), event.Sender(), event.EventID())
} }
// Query missing membership events
filter := gomatrixserverlib.DefaultStateFilter()
filter.Senders = &wantUsers
filter.Types = &[]string{gomatrixserverlib.MRoomMember}
memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter)
if err != nil {
return nil, err
} }
// cache the membership events
for _, membership := range memberships {
lazyLoadCache.StoreLazyLoadedUser(device, roomID, *membership.StateKey(), membership.EventID())
} }
// Add the membershipEvents to the end of the list, to make Sytest happy
return append(newState, membershipEvents...) return memberships, nil
} }
func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) { func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) {

View file

@ -64,6 +64,7 @@ type messagesResp struct {
// OnIncomingMessagesRequest implements the /messages endpoint from the // OnIncomingMessagesRequest implements the /messages endpoint from the
// client-server API. // client-server API.
// See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages // See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages
// nolint:gocyclo
func OnIncomingMessagesRequest( func OnIncomingMessagesRequest(
req *http.Request, db storage.Database, roomID string, device *userapi.Device, req *http.Request, db storage.Database, roomID string, device *userapi.Device,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
@ -246,7 +247,14 @@ func OnIncomingMessagesRequest(
Start: start.String(), Start: start.String(),
End: end.String(), End: end.String(),
} }
res.applyLazyLoadMembers(req.Context(), snapshot, roomID, device, filter.LazyLoadMembers, lazyLoadCache) if filter.LazyLoadMembers {
membershipEvents, err := applyLazyLoadMembers(req.Context(), device, snapshot, roomID, clientEvents, lazyLoadCache)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("failed to apply lazy loading")
return jsonerror.InternalServerError()
}
res.State = append(res.State, gomatrixserverlib.HeaderedToClientEvents(membershipEvents, gomatrixserverlib.FormatAll)...)
}
// If we didn't return any events, set the end to an empty string, so it will be omitted // If we didn't return any events, set the end to an empty string, so it will be omitted
// in the response JSON. // in the response JSON.
@ -265,40 +273,6 @@ func OnIncomingMessagesRequest(
} }
} }
// applyLazyLoadMembers loads membership events for users returned in Chunk, if the filter has
// LazyLoadMembers enabled.
func (m *messagesResp) applyLazyLoadMembers(
ctx context.Context,
db storage.DatabaseTransaction,
roomID string,
device *userapi.Device,
lazyLoad bool,
lazyLoadCache caching.LazyLoadCache,
) {
if !lazyLoad {
return
}
membershipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent)
for _, evt := range m.Chunk {
// Don't add membership events the client should already know about
if _, cached := lazyLoadCache.IsLazyLoadedUserCached(device, roomID, evt.Sender); cached {
continue
}
membership, err := db.GetStateEvent(ctx, roomID, gomatrixserverlib.MRoomMember, evt.Sender)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to get membership event for user")
continue
}
if membership != nil {
membershipToUser[evt.Sender] = membership
lazyLoadCache.StoreLazyLoadedUser(device, roomID, evt.Sender, membership.EventID())
}
}
for _, evt := range membershipToUser {
m.State = append(m.State, gomatrixserverlib.HeaderedToClientEvent(evt, gomatrixserverlib.FormatAll))
}
}
func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) { func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) {
req := api.QueryMembershipForUserRequest{ req := api.QueryMembershipForUserRequest{
RoomID: roomID, RoomID: roomID,

View file

@ -46,7 +46,7 @@ type DatabaseTransaction interface {
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error) GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error)
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) RecentEvents(ctx context.Context, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error)
GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error) GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error)
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)
InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error)

View file

@ -275,6 +275,15 @@ func (s *currentRoomStateStatements) SelectCurrentState(
) ([]*gomatrixserverlib.HeaderedEvent, error) { ) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter) senders, notSenders := getSendersStateFilterFilter(stateFilter)
// We're going to query members later, so remove them from this request
if stateFilter.LazyLoadMembers && !stateFilter.IncludeRedundantMembers {
notTypes := &[]string{gomatrixserverlib.MRoomMember}
if stateFilter.NotTypes != nil {
*stateFilter.NotTypes = append(*stateFilter.NotTypes, gomatrixserverlib.MRoomMember)
} else {
stateFilter.NotTypes = notTypes
}
}
rows, err := stmt.QueryContext(ctx, roomID, rows, err := stmt.QueryContext(ctx, roomID,
pq.StringArray(senders), pq.StringArray(senders),
pq.StringArray(notSenders), pq.StringArray(notSenders),

View file

@ -0,0 +1,29 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"context"
"database/sql"
"fmt"
)
func UpRenameOutputRoomEventsIndex(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `ALTER TABLE syncapi_output_room_events RENAME CONSTRAINT syncapi_event_id_idx TO syncapi_output_room_event_id_idx;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}

View file

@ -19,18 +19,17 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt"
"sort" "sort"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
const outputRoomEventsSchema = ` const outputRoomEventsSchema = `
@ -44,7 +43,7 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
-- This isn't a problem for us since we just want to order by this field. -- This isn't a problem for us since we just want to order by this field.
id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'), id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'),
-- The event ID for the event -- The event ID for the event
event_id TEXT NOT NULL CONSTRAINT syncapi_event_id_idx UNIQUE, event_id TEXT NOT NULL CONSTRAINT syncapi_output_room_event_id_idx UNIQUE,
-- The 'room_id' key for the event. -- The 'room_id' key for the event.
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
-- The headered JSON for the event, containing potentially additional metadata such as -- The headered JSON for the event, containing potentially additional metadata such as
@ -79,13 +78,16 @@ CREATE INDEX IF NOT EXISTS syncapi_output_room_events_room_id_idx ON syncapi_out
CREATE INDEX IF NOT EXISTS syncapi_output_room_events_exclude_from_sync_idx ON syncapi_output_room_events (exclude_from_sync); CREATE INDEX IF NOT EXISTS syncapi_output_room_events_exclude_from_sync_idx ON syncapi_output_room_events (exclude_from_sync);
CREATE INDEX IF NOT EXISTS syncapi_output_room_events_add_state_ids_idx ON syncapi_output_room_events ((add_state_ids IS NOT NULL)); CREATE INDEX IF NOT EXISTS syncapi_output_room_events_add_state_ids_idx ON syncapi_output_room_events ((add_state_ids IS NOT NULL));
CREATE INDEX IF NOT EXISTS syncapi_output_room_events_remove_state_ids_idx ON syncapi_output_room_events ((remove_state_ids IS NOT NULL)); CREATE INDEX IF NOT EXISTS syncapi_output_room_events_remove_state_ids_idx ON syncapi_output_room_events ((remove_state_ids IS NOT NULL));
CREATE INDEX IF NOT EXISTS syncapi_output_room_events_recent_events_idx ON syncapi_output_room_events (room_id, exclude_from_sync, id, sender, type);
` `
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO syncapi_output_room_events (" + "INSERT INTO syncapi_output_room_events (" +
"room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync, history_visibility" + "room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync, history_visibility" +
") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " +
"ON CONFLICT ON CONSTRAINT syncapi_event_id_idx DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $11) " + "ON CONFLICT ON CONSTRAINT syncapi_output_room_event_id_idx DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $11) " +
"RETURNING id" "RETURNING id"
const selectEventsSQL = "" + const selectEventsSQL = "" +
@ -109,14 +111,29 @@ const selectRecentEventsSQL = "" +
" AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" +
" ORDER BY id DESC LIMIT $8" " ORDER BY id DESC LIMIT $8"
const selectRecentEventsForSyncSQL = "" + // selectRecentEventsForSyncSQL contains an optimization to get the recent events for a list of rooms, using a LATERAL JOIN
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" + // The sub select inside LATERAL () is executed for all room_ids it gets as a parameter $1
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + const selectRecentEventsForSyncSQL = `
" AND ( $4::text[] IS NULL OR sender = ANY($4) )" + WITH room_ids AS (
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" + SELECT unnest($1::text[]) AS room_id
" AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" + )
" AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + SELECT x.*
" ORDER BY id DESC LIMIT $8" FROM room_ids,
LATERAL (
SELECT room_id, event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility
FROM syncapi_output_room_events recent_events
WHERE
recent_events.room_id = room_ids.room_id
AND recent_events.exclude_from_sync = FALSE
AND id > $2 AND id <= $3
AND ( $4::text[] IS NULL OR sender = ANY($4) )
AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )
AND ( $6::text[] IS NULL OR type LIKE ANY($6) )
AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )
ORDER BY recent_events.id DESC
LIMIT $8
) AS x
`
const selectEarlyEventsSQL = "" + const selectEarlyEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" +
@ -207,12 +224,30 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
return nil, err return nil, err
} }
migrationName := "syncapi: rename dupe index (output_room_events)"
var cName string
err = db.QueryRowContext(context.Background(), "select constraint_name from information_schema.table_constraints where table_name = 'syncapi_output_room_events' AND constraint_name = 'syncapi_event_id_idx'").Scan(&cName)
switch err {
case sql.ErrNoRows: // migration was already executed, as the index was renamed
if err = sqlutil.InsertMigration(context.Background(), db, migrationName); err != nil {
return nil, fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err)
}
case nil:
default:
return nil, err
}
m := sqlutil.NewMigrator(db) m := sqlutil.NewMigrator(db)
m.AddMigrations( m.AddMigrations(
sqlutil.Migration{ sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)", Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents, Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
}, },
sqlutil.Migration{
Version: migrationName,
Up: deltas.UpRenameOutputRoomEventsIndex,
},
) )
err = m.Up(context.Background()) err = m.Up(context.Background())
if err != nil { if err != nil {
@ -398,9 +433,9 @@ func (s *outputRoomEventsStatements) InsertEvent(
// from sync. // from sync.
func (s *outputRoomEventsStatements) SelectRecentEvents( func (s *outputRoomEventsStatements) SelectRecentEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, roomIDs []string, ra types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
chronologicalOrder bool, onlySyncEvents bool, chronologicalOrder bool, onlySyncEvents bool,
) ([]types.StreamEvent, bool, error) { ) (map[string]types.RecentEvents, error) {
var stmt *sql.Stmt var stmt *sql.Stmt
if onlySyncEvents { if onlySyncEvents {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt) stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt)
@ -408,8 +443,9 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
} }
senders, notSenders := getSendersRoomEventFilter(eventFilter) senders, notSenders := getSendersRoomEventFilter(eventFilter)
rows, err := stmt.QueryContext( rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(), ctx, pq.StringArray(roomIDs), ra.Low(), ra.High(),
pq.StringArray(senders), pq.StringArray(senders),
pq.StringArray(notSenders), pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
@ -417,34 +453,80 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
eventFilter.Limit+1, eventFilter.Limit+1,
) )
if err != nil { if err != nil {
return nil, false, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed")
events, err := rowsToStreamEvents(rows)
if err != nil { result := make(map[string]types.RecentEvents)
return nil, false, err
for rows.Next() {
var (
roomID string
eventID string
streamPos types.StreamPosition
eventBytes []byte
excludeFromSync bool
sessionID *int64
txnID *string
transactionID *api.TransactionID
historyVisibility gomatrixserverlib.HistoryVisibility
)
if err := rows.Scan(&roomID, &eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID, &historyVisibility); err != nil {
return nil, err
} }
if chronologicalOrder { // TODO: Handle redacted events
// The events need to be returned from oldest to latest, which isn't var ev gomatrixserverlib.HeaderedEvent
// necessary the way the SQL query returns them, so a sort is necessary to if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
// ensure the events are in the right order in the slice. return nil, err
sort.SliceStable(events, func(i int, j int) bool {
return events[i].StreamPosition < events[j].StreamPosition
})
} }
// we queried for 1 more than the limit, so if we returned one more mark limited=true
limited := false if sessionID != nil && txnID != nil {
if len(events) > eventFilter.Limit { transactionID = &api.TransactionID{
limited = true SessionID: *sessionID,
// re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. TransactionID: *txnID,
if chronologicalOrder {
events = events[1:]
} else {
events = events[:len(events)-1]
} }
} }
return events, limited, nil r := result[roomID]
ev.Visibility = historyVisibility
r.Events = append(r.Events, types.StreamEvent{
HeaderedEvent: &ev,
StreamPosition: streamPos,
TransactionID: transactionID,
ExcludeFromSync: excludeFromSync,
})
result[roomID] = r
}
if chronologicalOrder {
for roomID, evs := range result {
// The events need to be returned from oldest to latest, which isn't
// necessary the way the SQL query returns them, so a sort is necessary to
// ensure the events are in the right order in the slice.
sort.SliceStable(evs.Events, func(i int, j int) bool {
return evs.Events[i].StreamPosition < evs.Events[j].StreamPosition
})
if len(evs.Events) > eventFilter.Limit {
evs.Limited = true
evs.Events = evs.Events[1:]
}
result[roomID] = evs
}
} else {
for roomID, evs := range result {
if len(evs.Events) > eventFilter.Limit {
evs.Limited = true
evs.Events = evs.Events[:len(evs.Events)-1]
}
result[roomID] = evs
}
}
return result, rows.Err()
} }
// selectEarlyEvents returns the earliest events in the given room, starting // selectEarlyEvents returns the earliest events in the given room, starting

View file

@ -151,8 +151,8 @@ func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID
return summary, nil return summary, nil
} }
func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) {
return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomIDs, r, eventFilter, chronologicalOrder, onlySyncEvents)
} }
func (d *DatabaseTransaction) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) { func (d *DatabaseTransaction) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) {
@ -370,20 +370,26 @@ func (d *DatabaseTransaction) GetStateDeltas(
} }
// get all the state events ever (i.e. for all available rooms) between these two positions // get all the state events ever (i.e. for all available rooms) between these two positions
stateNeededFiltered, eventMapFiltered, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) stateFiltered := state
// avoid hitting the database if the result would be the same as above
if !isStatefilterEmpty(stateFilter) {
var stateNeededFiltered map[string]map[string]bool
var eventMapFiltered map[string]types.StreamEvent
stateNeededFiltered, eventMapFiltered, err = d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, nil return nil, nil, nil
} }
return nil, nil, err return nil, nil, err
} }
stateFiltered, err := d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered) stateFiltered, err = d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, nil return nil, nil, nil
} }
return nil, nil, err return nil, nil, err
} }
}
// find out which rooms this user is peeking, if any. // find out which rooms this user is peeking, if any.
// We do this before joins so any peeks get overwritten // We do this before joins so any peeks get overwritten
@ -701,6 +707,28 @@ func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context)
return types.StreamPosition(id), err return types.StreamPosition(id), err
} }
func isStatefilterEmpty(filter *gomatrixserverlib.StateFilter) bool {
if filter == nil {
return true
}
switch {
case filter.NotTypes != nil && len(*filter.NotTypes) > 0:
return false
case filter.Types != nil && len(*filter.Types) > 0:
return false
case filter.Senders != nil && len(*filter.Senders) > 0:
return false
case filter.NotSenders != nil && len(*filter.NotSenders) > 0:
return false
case filter.NotRooms != nil && len(*filter.NotRooms) > 0:
return false
case filter.ContainsURL != nil:
return false
default:
return true
}
}
func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) ( func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (
events []types.StreamEvent, prevBatch, nextBatch string, err error, events []types.StreamEvent, prevBatch, nextBatch string, err error,
) { ) {

View file

@ -0,0 +1,72 @@
package shared
import (
"testing"
"github.com/matrix-org/gomatrixserverlib"
)
func Test_isStatefilterEmpty(t *testing.T) {
filterSet := []string{"a"}
boolValue := false
tests := []struct {
name string
filter *gomatrixserverlib.StateFilter
want bool
}{
{
name: "nil filter is empty",
filter: nil,
want: true,
},
{
name: "Empty filter is empty",
filter: &gomatrixserverlib.StateFilter{},
want: true,
},
{
name: "NotTypes is set",
filter: &gomatrixserverlib.StateFilter{
NotTypes: &filterSet,
},
},
{
name: "Types is set",
filter: &gomatrixserverlib.StateFilter{
Types: &filterSet,
},
},
{
name: "Senders is set",
filter: &gomatrixserverlib.StateFilter{
Senders: &filterSet,
},
},
{
name: "NotSenders is set",
filter: &gomatrixserverlib.StateFilter{
NotSenders: &filterSet,
},
},
{
name: "NotRooms is set",
filter: &gomatrixserverlib.StateFilter{
NotRooms: &filterSet,
},
},
{
name: "ContainsURL is set",
filter: &gomatrixserverlib.StateFilter{
ContainsURL: &boolValue,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isStatefilterEmpty(tt.filter); got != tt.want {
t.Errorf("isStatefilterEmpty() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -105,7 +105,6 @@ func (s *accountDataStatements) SelectAccountDataInRange(
filter.Senders, filter.NotSenders, filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes, filter.Types, filter.NotTypes,
[]string{}, nil, filter.Limit, FilterOrderAsc) []string{}, nil, filter.Limit, FilterOrderAsc)
if err != nil { if err != nil {
return return
} }

View file

@ -267,6 +267,15 @@ func (s *currentRoomStateStatements) SelectCurrentState(
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
excludeEventIDs []string, excludeEventIDs []string,
) ([]*gomatrixserverlib.HeaderedEvent, error) { ) ([]*gomatrixserverlib.HeaderedEvent, error) {
// We're going to query members later, so remove them from this request
if stateFilter.LazyLoadMembers && !stateFilter.IncludeRedundantMembers {
notTypes := &[]string{gomatrixserverlib.MRoomMember}
if stateFilter.NotTypes != nil {
*stateFilter.NotTypes = append(*stateFilter.NotTypes, gomatrixserverlib.MRoomMember)
} else {
stateFilter.NotTypes = notTypes
}
}
stmt, params, err := prepareWithFilters( stmt, params, err := prepareWithFilters(
s.db, txn, selectCurrentStateSQL, s.db, txn, selectCurrentStateSQL,
[]interface{}{ []interface{}{

View file

@ -368,9 +368,9 @@ func (s *outputRoomEventsStatements) InsertEvent(
func (s *outputRoomEventsStatements) SelectRecentEvents( func (s *outputRoomEventsStatements) SelectRecentEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
chronologicalOrder bool, onlySyncEvents bool, chronologicalOrder bool, onlySyncEvents bool,
) ([]types.StreamEvent, bool, error) { ) (map[string]types.RecentEvents, error) {
var query string var query string
if onlySyncEvents { if onlySyncEvents {
query = selectRecentEventsForSyncSQL query = selectRecentEventsForSyncSQL
@ -378,6 +378,8 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
query = selectRecentEventsSQL query = selectRecentEventsSQL
} }
result := make(map[string]types.RecentEvents, len(roomIDs))
for _, roomID := range roomIDs {
stmt, params, err := prepareWithFilters( stmt, params, err := prepareWithFilters(
s.db, txn, query, s.db, txn, query,
[]interface{}{ []interface{}{
@ -388,18 +390,18 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc, nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc,
) )
if err != nil { if err != nil {
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
} }
defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed") defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, false, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed")
events, err := rowsToStreamEvents(rows) events, err := rowsToStreamEvents(rows)
if err != nil { if err != nil {
return nil, false, err return nil, err
} }
if chronologicalOrder { if chronologicalOrder {
// The events need to be returned from oldest to latest, which isn't // The events need to be returned from oldest to latest, which isn't
@ -409,10 +411,10 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
return events[i].StreamPosition < events[j].StreamPosition return events[i].StreamPosition < events[j].StreamPosition
}) })
} }
res := types.RecentEvents{}
// we queried for 1 more than the limit, so if we returned one more mark limited=true // we queried for 1 more than the limit, so if we returned one more mark limited=true
limited := false
if len(events) > eventFilter.Limit { if len(events) > eventFilter.Limit {
limited = true res.Limited = true
// re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. // re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last.
if chronologicalOrder { if chronologicalOrder {
events = events[1:] events = events[1:]
@ -420,7 +422,11 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
events = events[:len(events)-1] events = events[:len(events)-1]
} }
} }
return events, limited, nil res.Events = events
result[roomID] = res
}
return result, nil
} }
func (s *outputRoomEventsStatements) SelectEarlyEvents( func (s *outputRoomEventsStatements) SelectEarlyEvents(

View file

@ -156,12 +156,12 @@ func TestRecentEventsPDU(t *testing.T) {
tc := testCases[i] tc := testCases[i]
t.Run(tc.Name, func(st *testing.T) { t.Run(tc.Name, func(st *testing.T) {
var filter gomatrixserverlib.RoomEventFilter var filter gomatrixserverlib.RoomEventFilter
var gotEvents []types.StreamEvent var gotEvents map[string]types.RecentEvents
var limited bool var limited bool
filter.Limit = tc.Limit filter.Limit = tc.Limit
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
var err error var err error
gotEvents, limited, err = snapshot.RecentEvents(ctx, r.ID, types.Range{ gotEvents, err = snapshot.RecentEvents(ctx, []string{r.ID}, types.Range{
From: tc.From, From: tc.From,
To: tc.To, To: tc.To,
}, &filter, !tc.ReverseOrder, true) }, &filter, !tc.ReverseOrder, true)
@ -169,15 +169,18 @@ func TestRecentEventsPDU(t *testing.T) {
st.Fatalf("failed to do sync: %s", err) st.Fatalf("failed to do sync: %s", err)
} }
}) })
streamEvents := gotEvents[r.ID]
limited = streamEvents.Limited
if limited != tc.WantLimited { if limited != tc.WantLimited {
st.Errorf("got limited=%v want %v", limited, tc.WantLimited) st.Errorf("got limited=%v want %v", limited, tc.WantLimited)
} }
if len(gotEvents) != len(tc.WantEvents) { if len(streamEvents.Events) != len(tc.WantEvents) {
st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents)) st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents))
} }
for j := range gotEvents {
if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) { for j := range streamEvents.Events {
st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON())) if !reflect.DeepEqual(streamEvents.Events[j].JSON(), tc.WantEvents[j].JSON()) {
st.Errorf("event %d got %s want %s", j, string(streamEvents.Events[j].JSON()), string(tc.WantEvents[j].JSON()))
} }
} }
}) })
@ -923,3 +926,61 @@ func TestRoomSummary(t *testing.T) {
} }
}) })
} }
func TestRecentEvents(t *testing.T) {
alice := test.NewUser(t)
room1 := test.NewRoom(t, alice)
room2 := test.NewRoom(t, alice)
roomIDs := []string{room1.ID, room2.ID}
rooms := map[string]*test.Room{
room1.ID: room1,
room2.ID: room2,
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
filter := gomatrixserverlib.DefaultRoomEventFilter()
db, close, closeBase := MustCreateDatabase(t, dbType)
t.Cleanup(func() {
close()
closeBase()
})
MustWriteEvents(t, db, room1.Events())
MustWriteEvents(t, db, room2.Events())
transaction, err := db.NewDatabaseTransaction(ctx)
assert.NoError(t, err)
defer transaction.Rollback()
// get all recent events from 0 to 100 (we only created 5 events, so we should get 5 back)
roomEvs, err := transaction.RecentEvents(ctx, roomIDs, types.Range{From: 0, To: 100}, &filter, true, true)
assert.NoError(t, err)
assert.Equal(t, len(roomEvs), 2, "unexpected recent events response")
for _, recentEvents := range roomEvs {
assert.Equal(t, 5, len(recentEvents.Events), "unexpected recent events for room")
}
// update the filter to only return one event
filter.Limit = 1
roomEvs, err = transaction.RecentEvents(ctx, roomIDs, types.Range{From: 0, To: 100}, &filter, true, true)
assert.NoError(t, err)
assert.Equal(t, len(roomEvs), 2, "unexpected recent events response")
for roomID, recentEvents := range roomEvs {
origEvents := rooms[roomID].Events()
assert.Equal(t, true, recentEvents.Limited, "expected events to be limited")
assert.Equal(t, 1, len(recentEvents.Events), "unexpected recent events for room")
assert.Equal(t, origEvents[len(origEvents)-1].EventID(), recentEvents.Events[0].EventID())
}
// not chronologically ordered still returns the events in order (given ORDER BY id DESC)
roomEvs, err = transaction.RecentEvents(ctx, roomIDs, types.Range{From: 0, To: 100}, &filter, false, true)
assert.NoError(t, err)
assert.Equal(t, len(roomEvs), 2, "unexpected recent events response")
for roomID, recentEvents := range roomEvs {
origEvents := rooms[roomID].Events()
assert.Equal(t, true, recentEvents.Limited, "expected events to be limited")
assert.Equal(t, 1, len(recentEvents.Events), "unexpected recent events for room")
assert.Equal(t, origEvents[len(origEvents)-1].EventID(), recentEvents.Events[0].EventID())
}
})
}

View file

@ -13,6 +13,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
) )
func newCurrentRoomStateTable(t *testing.T, dbType test.DBType) (tables.CurrentRoomState, *sql.DB, func()) { func newCurrentRoomStateTable(t *testing.T, dbType test.DBType) (tables.CurrentRoomState, *sql.DB, func()) {
@ -79,6 +80,9 @@ func TestCurrentRoomStateTable(t *testing.T) {
return fmt.Errorf("SelectEventsWithEventIDs\nexpected id %q not returned", id) return fmt.Errorf("SelectEventsWithEventIDs\nexpected id %q not returned", id)
} }
} }
testCurrentState(t, ctx, txn, tab, room)
return nil return nil
}) })
if err != nil { if err != nil {
@ -86,3 +90,39 @@ func TestCurrentRoomStateTable(t *testing.T) {
} }
}) })
} }
func testCurrentState(t *testing.T, ctx context.Context, txn *sql.Tx, tab tables.CurrentRoomState, room *test.Room) {
t.Run("test currentState", func(t *testing.T) {
// returns the complete state of the room with a default filter
filter := gomatrixserverlib.DefaultStateFilter()
evs, err := tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil)
if err != nil {
t.Fatal(err)
}
expectCount := 5
if gotCount := len(evs); gotCount != expectCount {
t.Fatalf("expected %d state events, got %d", expectCount, gotCount)
}
// When lazy loading, we expect no membership event, so only 4 events
filter.LazyLoadMembers = true
expectCount = 4
evs, err = tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil)
if err != nil {
t.Fatal(err)
}
if gotCount := len(evs); gotCount != expectCount {
t.Fatalf("expected %d state events, got %d", expectCount, gotCount)
}
// same as above, but with existing NotTypes defined
notTypes := []string{gomatrixserverlib.MRoomMember}
filter.NotTypes = &notTypes
evs, err = tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil)
if err != nil {
t.Fatal(err)
}
if gotCount := len(evs); gotCount != expectCount {
t.Fatalf("expected %d state events, got %d", expectCount, gotCount)
}
})
}

View file

@ -66,7 +66,7 @@ type Events interface {
// SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high. // SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high.
// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync. // If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync.
// Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`. // Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`.
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error)
// SelectEarlyEvents returns the earliest events in the given room. // SelectEarlyEvents returns the earliest events in the given room.
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error)

View file

@ -82,19 +82,24 @@ func (p *PDUStreamProvider) CompleteSync(
req.Log.WithError(err).Error("unable to update event filter with ignored users") req.Log.WithError(err).Error("unable to update event filter with ignored users")
} }
recentEvents, err := snapshot.RecentEvents(ctx, joinedRoomIDs, r, &eventFilter, true, true)
if err != nil {
return from
}
// Build up a /sync response. Add joined rooms.
for _, roomID := range joinedRoomIDs {
events := recentEvents[roomID]
// Invalidate the lazyLoadCache, otherwise we end up with missing displaynames/avatars // Invalidate the lazyLoadCache, otherwise we end up with missing displaynames/avatars
// TODO: This might be inefficient, when joined to many and/or large rooms. // TODO: This might be inefficient, when joined to many and/or large rooms.
for _, roomID := range joinedRoomIDs {
joinedUsers := p.notifier.JoinedUsers(roomID) joinedUsers := p.notifier.JoinedUsers(roomID)
for _, sharedUser := range joinedUsers { for _, sharedUser := range joinedUsers {
p.lazyLoadCache.InvalidateLazyLoadedUser(req.Device, roomID, sharedUser) p.lazyLoadCache.InvalidateLazyLoadedUser(req.Device, roomID, sharedUser)
} }
}
// Build up a /sync response. Add joined rooms. // get the join response for each room
for _, roomID := range joinedRoomIDs {
jr, jerr := p.getJoinResponseForCompleteSync( jr, jerr := p.getJoinResponseForCompleteSync(
ctx, snapshot, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false, ctx, snapshot, roomID, &stateFilter, req.WantFullState, req.Device, false,
events.Events, events.Limited,
) )
if jerr != nil { if jerr != nil {
req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed") req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed")
@ -113,11 +118,25 @@ func (p *PDUStreamProvider) CompleteSync(
req.Log.WithError(err).Error("p.DB.PeeksInRange failed") req.Log.WithError(err).Error("p.DB.PeeksInRange failed")
return from return from
} }
if len(peeks) > 0 {
peekRooms := make([]string, 0, len(peeks))
for _, peek := range peeks { for _, peek := range peeks {
if !peek.Deleted { if !peek.Deleted {
peekRooms = append(peekRooms, peek.RoomID)
}
}
recentEvents, err = snapshot.RecentEvents(ctx, peekRooms, r, &eventFilter, true, true)
if err != nil {
return from
}
for _, roomID := range peekRooms {
var jr *types.JoinResponse var jr *types.JoinResponse
events := recentEvents[roomID]
jr, err = p.getJoinResponseForCompleteSync( jr, err = p.getJoinResponseForCompleteSync(
ctx, snapshot, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true, ctx, snapshot, roomID, &stateFilter, req.WantFullState, req.Device, true,
events.Events, events.Limited,
) )
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
@ -126,7 +145,7 @@ func (p *PDUStreamProvider) CompleteSync(
} }
continue continue
} }
req.Response.Rooms.Peek[peek.RoomID] = jr req.Response.Rooms.Peek[roomID] = jr
} }
} }
@ -227,7 +246,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
req *types.SyncRequest, req *types.SyncRequest,
) (types.StreamPosition, error) { ) (types.StreamPosition, error) {
var err error
originalLimit := eventFilter.Limit originalLimit := eventFilter.Limit
// If we're going backwards, grep at least X events, this is mostly to satisfy Sytest // If we're going backwards, grep at least X events, this is mostly to satisfy Sytest
if r.Backwards && originalLimit < recentEventBackwardsLimit { if r.Backwards && originalLimit < recentEventBackwardsLimit {
@ -238,8 +257,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
} }
} }
recentStreamEvents, limited, err := snapshot.RecentEvents( dbEvents, err := snapshot.RecentEvents(
ctx, delta.RoomID, r, ctx, []string{delta.RoomID}, r,
eventFilter, true, true, eventFilter, true, true,
) )
if err != nil { if err != nil {
@ -248,6 +267,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
} }
return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err) return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err)
} }
recentStreamEvents := dbEvents[delta.RoomID].Events
limited := dbEvents[delta.RoomID].Limited
recentEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering( recentEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering(
snapshot.StreamEventsToEvents(device, recentStreamEvents), snapshot.StreamEventsToEvents(device, recentStreamEvents),
gomatrixserverlib.TopologicalOrderByPrevEvents, gomatrixserverlib.TopologicalOrderByPrevEvents,
@ -420,7 +443,7 @@ func applyHistoryVisibilityFilter(
"room_id": roomID, "room_id": roomID,
"before": len(recentEvents), "before": len(recentEvents),
"after": len(events), "after": len(events),
}).Trace("Applied history visibility (sync)") }).Debugf("Applied history visibility (sync)")
return events, nil return events, nil
} }
@ -428,25 +451,16 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseTransaction, snapshot storage.DatabaseTransaction,
roomID string, roomID string,
r types.Range,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
eventFilter *gomatrixserverlib.RoomEventFilter,
wantFullState bool, wantFullState bool,
device *userapi.Device, device *userapi.Device,
isPeek bool, isPeek bool,
recentStreamEvents []types.StreamEvent,
limited bool,
) (jr *types.JoinResponse, err error) { ) (jr *types.JoinResponse, err error) {
jr = types.NewJoinResponse() jr = types.NewJoinResponse()
// TODO: When filters are added, we may need to call this multiple times to get enough events. // TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
recentStreamEvents, limited, err := snapshot.RecentEvents(
ctx, roomID, r, eventFilter, true, true,
)
if err != nil {
if err == sql.ErrNoRows {
return jr, nil
}
return
}
// Work our way through the timeline events and pick out the event IDs // Work our way through the timeline events and pick out the event IDs
// of any state events that appear in the timeline. We'll specifically // of any state events that appear in the timeline. We'll specifically

View file

@ -10,6 +10,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/matrix-org/dendrite/syncapi/routing"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@ -448,6 +449,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{ base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{
"access_token": bobDev.AccessToken, "access_token": bobDev.AccessToken,
"dir": "b", "dir": "b",
"filter": `{"lazy_load_members":true}`, // check that lazy loading doesn't break history visibility
}))) })))
if w.Code != 200 { if w.Code != 200 {
t.Logf("%s", w.Body.String()) t.Logf("%s", w.Body.String())
@ -905,6 +907,177 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
} }
} }
func TestContext(t *testing.T) {
test.WithAllDatabases(t, testContext)
}
func testContext(t *testing.T, dbType test.DBType) {
tests := []struct {
name string
roomID string
eventID string
params map[string]string
wantError bool
wantStateLength int
wantBeforeLength int
wantAfterLength int
}{
{
name: "invalid filter",
params: map[string]string{
"filter": "{",
},
wantError: true,
},
{
name: "invalid limit",
params: map[string]string{
"limit": "abc",
},
wantError: true,
},
{
name: "high limit",
params: map[string]string{
"limit": "100000",
},
},
{
name: "fine limit",
params: map[string]string{
"limit": "10",
},
},
{
name: "last event without lazy loading",
wantStateLength: 5,
},
{
name: "last event with lazy loading",
params: map[string]string{
"filter": `{"lazy_load_members":true}`,
},
wantStateLength: 1,
},
{
name: "invalid room",
roomID: "!doesnotexist",
wantError: true,
},
{
name: "invalid eventID",
eventID: "$doesnotexist",
wantError: true,
},
{
name: "state is limited",
params: map[string]string{
"limit": "1",
},
wantStateLength: 1,
},
{
name: "events are not limited",
wantBeforeLength: 7,
},
{
name: "all events are limited",
params: map[string]string{
"limit": "1",
},
wantStateLength: 1,
wantBeforeLength: 1,
wantAfterLength: 1,
},
}
user := test.NewUser(t)
alice := userapi.Device{
ID: "ALICEID",
UserID: user.ID,
AccessToken: "ALICE_BEARER_TOKEN",
DisplayName: "Alice",
AccountType: userapi.AccountTypeUser,
}
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
defer baseClose()
// Use an actual roomserver for this
rsAPI := roomserver.NewInternalAPI(base)
rsAPI.SetFederationAPI(nil, nil)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, &syncKeyAPI{})
room := test.NewRoom(t, user)
room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world 1!"})
room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world 2!"})
thirdMsg := room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world3!"})
room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world4!"})
if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
}
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
syncUntil(t, base, alice.AccessToken, false, func(syncBody string) bool {
// wait for the last sent eventID to come down sync
path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, thirdMsg.EventID())
return gjson.Get(syncBody, path).Exists()
})
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
params := map[string]string{
"access_token": alice.AccessToken,
}
w := httptest.NewRecorder()
// test overrides
roomID := room.ID
if tc.roomID != "" {
roomID = tc.roomID
}
eventID := thirdMsg.EventID()
if tc.eventID != "" {
eventID = tc.eventID
}
requestPath := fmt.Sprintf("/_matrix/client/v3/rooms/%s/context/%s", roomID, eventID)
if tc.params != nil {
for k, v := range tc.params {
params[k] = v
}
}
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", requestPath, test.WithQueryParams(params)))
if tc.wantError && w.Code == 200 {
t.Fatalf("Expected an error, but got none")
}
t.Log(w.Body.String())
resp := routing.ContextRespsonse{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if tc.wantStateLength > 0 && tc.wantStateLength != len(resp.State) {
t.Fatalf("expected %d state events, got %d", tc.wantStateLength, len(resp.State))
}
if tc.wantBeforeLength > 0 && tc.wantBeforeLength != len(resp.EventsBefore) {
t.Fatalf("expected %d before events, got %d", tc.wantBeforeLength, len(resp.EventsBefore))
}
if tc.wantAfterLength > 0 && tc.wantAfterLength != len(resp.EventsAfter) {
t.Fatalf("expected %d after events, got %d", tc.wantAfterLength, len(resp.EventsAfter))
}
if !tc.wantError && resp.Event.EventID != eventID {
t.Fatalf("unexpected eventID %s, expected %s", resp.Event.EventID, eventID)
}
})
}
}
func syncUntil(t *testing.T, func syncUntil(t *testing.T,
base *base.BaseDendrite, accessToken string, base *base.BaseDendrite, accessToken string,
skip bool, skip bool,

View file

@ -63,6 +63,11 @@ type StreamEvent struct {
ExcludeFromSync bool ExcludeFromSync bool
} }
type RecentEvents struct {
Limited bool
Events []StreamEvent
}
// Range represents a range between two stream positions. // Range represents a range between two stream positions.
type Range struct { type Range struct {
// From is the position the client has already received. // From is the position the client has already received.