Merge branch 'matrix-org-main'

This commit is contained in:
Brian Meek 2023-02-07 17:30:09 -08:00
commit c74798b5ab
No known key found for this signature in database
GPG key ID: ACBD71263BF42D00
28 changed files with 843 additions and 221 deletions

View file

@ -273,7 +273,11 @@ jobs:
- name: Setup go
uses: actions/setup-go@v3
with:
<<<<<<< HEAD
go-version: "1.19"
=======
go-version: "1.18"
>>>>>>> eb29a315507f0075c2c6a495ac59c64a7f45f9fc
cache: true
- name: Build upgrade-tests
run: go build ./cmd/dendrite-upgrade-tests
@ -293,7 +297,11 @@ jobs:
- name: Setup go
uses: actions/setup-go@v3
with:
<<<<<<< HEAD
go-version: "1.19"
=======
go-version: "1.18"
>>>>>>> eb29a315507f0075c2c6a495ac59c64a7f45f9fc
cache: true
- name: Build upgrade-tests
run: go build ./cmd/dendrite-upgrade-tests

View file

@ -433,7 +433,7 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
return nil
}
// QueryMembershipAtEventRequest requests the membership events for a user
// QueryMembershipAtEventRequest requests the membership event for a user
// for a list of eventIDs.
type QueryMembershipAtEventRequest struct {
RoomID string
@ -443,9 +443,10 @@ type QueryMembershipAtEventRequest struct {
// QueryMembershipAtEventResponse is the response to QueryMembershipAtEventRequest.
type QueryMembershipAtEventResponse struct {
// Memberships is a map from eventID to a list of events (if any). Events that
// do not have known state will return an empty array here.
Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"`
// Membership is a map from eventID to membership event. Events that
// do not have known state will return a nil event, resulting in a "leave" membership
// when calculating history visibility.
Membership map[string]*gomatrixserverlib.HeaderedEvent `json:"membership"`
}
// QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a

View file

@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@ -216,7 +217,8 @@ func (r *Queryer) QueryMembershipAtEvent(
request *api.QueryMembershipAtEventRequest,
response *api.QueryMembershipAtEventResponse,
) error {
response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent)
response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent)
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return fmt.Errorf("unable to get roomInfo: %w", err)
@ -234,7 +236,17 @@ func (r *Queryer) QueryMembershipAtEvent(
return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID)
}
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, request.EventIDs, stateKeyNIDs[request.UserID])
response.Membership, err = r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[request.UserID], info, request.EventIDs...)
switch err {
case nil:
return nil
case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event
default:
return err
}
response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent)
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID])
if err != nil {
return fmt.Errorf("unable to get state before event: %w", err)
}
@ -258,7 +270,7 @@ func (r *Queryer) QueryMembershipAtEvent(
for _, eventID := range request.EventIDs {
stateEntry, ok := stateEntries[eventID]
if !ok || len(stateEntry) == 0 {
response.Memberships[eventID] = []*gomatrixserverlib.HeaderedEvent{}
response.Membership[eventID] = nil
continue
}
@ -275,15 +287,15 @@ func (r *Queryer) QueryMembershipAtEvent(
return fmt.Errorf("unable to get memberships at state: %w", err)
}
res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships))
// Iterate over all membership events we got. Given we only query the membership for
// one user and assuming this user only ever has one membership event associated to
// a given event, overwrite any other existing membership events.
for i := range memberships {
ev := memberships[i]
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) {
res = append(res, ev.Headered(info.RoomVersion))
response.Membership[eventID] = ev.Event.Headered(info.RoomVersion)
}
}
response.Memberships[eventID] = res
}
return nil

View file

@ -175,4 +175,11 @@ type Database interface {
GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error)
PurgeRoom(ctx context.Context, roomID string) error
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
// GetMembershipForHistoryVisibility queries the membership events for the given eventIDs.
// Returns a map from (input) eventID -> membership event. If no membership event is found, returns an empty event, resulting in
// a membership of "leave" when calculating history visibility.
GetMembershipForHistoryVisibility(
ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error)
}

View file

@ -69,6 +69,9 @@ CREATE TABLE IF NOT EXISTS roomserver_events (
auth_event_nids BIGINT[] NOT NULL,
is_rejected BOOLEAN NOT NULL DEFAULT FALSE
);
-- Create an index which helps in resolving membership events (event_type_nid = 5) - (used for history visibility)
CREATE INDEX IF NOT EXISTS roomserver_events_memberships_idx ON roomserver_events (room_nid, event_state_key_nid) WHERE (event_type_nid = 5);
`
const insertEventSQL = "" +

View file

@ -21,10 +21,10 @@ import (
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
)
@ -99,10 +99,26 @@ const bulkSelectStateForHistoryVisibilitySQL = `
AND (event_type_nid = 7 OR event_state_key LIKE '%:' || $2);
`
// bulkSelectMembershipForHistoryVisibilitySQL is an optimization to get membership events for a specific user for defined set of events.
// Returns the event_id of the event we want the membership event for, the event_id of the membership event and the membership event JSON.
const bulkSelectMembershipForHistoryVisibilitySQL = `
SELECT re.event_id, re2.event_id, rej.event_json
FROM roomserver_events re
LEFT JOIN roomserver_state_snapshots rss on re.state_snapshot_nid = rss.state_snapshot_nid
CROSS JOIN unnest(rss.state_block_nids) AS blocks(block_nid)
LEFT JOIN roomserver_state_block rsb ON rsb.state_block_nid = blocks.block_nid
CROSS JOIN unnest(rsb.event_nids) AS rsb2(event_nid)
JOIN roomserver_events re2 ON re2.room_nid = $3 AND re2.event_type_nid = 5 AND re2.event_nid = rsb2.event_nid AND re2.event_state_key_nid = $1
LEFT JOIN roomserver_event_json rej ON rej.event_nid = re2.event_nid
WHERE re.event_id = ANY($2)
`
type stateSnapshotStatements struct {
insertStateStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt *sql.Stmt
bulkSelectStateForHistoryVisibilityStmt *sql.Stmt
bulktSelectMembershipForHistoryVisibilityStmt *sql.Stmt
}
func CreateStateSnapshotTable(db *sql.DB) error {
@ -110,13 +126,14 @@ func CreateStateSnapshotTable(db *sql.DB) error {
return err
}
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
func PrepareStateSnapshotTable(db *sql.DB) (*stateSnapshotStatements, error) {
s := &stateSnapshotStatements{}
return s, sqlutil.StatementList{
{&s.insertStateStmt, insertStateSQL},
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
{&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL},
{&s.bulktSelectMembershipForHistoryVisibilityStmt, bulkSelectMembershipForHistoryVisibilitySQL},
}.Prepare(db)
}
@ -185,3 +202,45 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility(
}
return results, rows.Err()
}
func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility(
ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.bulktSelectMembershipForHistoryVisibilityStmt)
rows, err := stmt.QueryContext(ctx, userNID, pq.Array(eventIDs), roomInfo.RoomNID)
if err != nil {
return nil, err
}
defer rows.Close() // nolint: errcheck
result := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs))
var evJson []byte
var eventID string
var membershipEventID string
knownEvents := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs))
for rows.Next() {
if err = rows.Scan(&eventID, &membershipEventID, &evJson); err != nil {
return nil, err
}
if len(evJson) == 0 {
result[eventID] = &gomatrixserverlib.HeaderedEvent{}
continue
}
// If we already know this event, don't try to marshal the json again
if ev, ok := knownEvents[membershipEventID]; ok {
result[eventID] = ev
continue
}
event, err := gomatrixserverlib.NewEventFromTrustedJSON(evJson, false, roomInfo.RoomVersion)
if err != nil {
result[eventID] = &gomatrixserverlib.HeaderedEvent{}
// not fatal
continue
}
he := event.Headered(roomInfo.RoomVersion)
result[eventID] = he
knownEvents[membershipEventID] = he
}
return result, rows.Err()
}

View file

@ -51,6 +51,12 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
return true
}
func (d *Database) GetMembershipForHistoryVisibility(
ctx context.Context, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error) {
return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...)
}
func (d *Database) EventTypeNIDs(
ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {

View file

@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@ -152,6 +153,10 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility(
return nil, tables.OptimisationNotSupportedError
}
func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string) (map[string]*gomatrixserverlib.HeaderedEvent, error) {
return nil, tables.OptimisationNotSupportedError
}
func (s *stateSnapshotStatements) selectStateBlockNIDsForRoomNID(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.StateBlockNID, error) {

View file

@ -91,6 +91,10 @@ type StateSnapshot interface {
// which users are in a room faster than having to load the entire room state. In the
// case of SQLite, this will return tables.OptimisationNotSupportedError.
BulkSelectStateForHistoryVisibility(ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string) ([]types.EventNID, error)
BulkSelectMembershipForHistoryVisibility(
ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error)
}
type StateBlock interface {

View file

@ -29,6 +29,8 @@ func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables.
assert.NoError(t, err)
err = postgres.CreateEventsTable(db)
assert.NoError(t, err)
err = postgres.CreateEventJSONTable(db)
assert.NoError(t, err)
err = postgres.CreateStateBlockTable(db)
assert.NoError(t, err)
// ... and then the snapshot table itself

View file

@ -121,10 +121,7 @@ func ApplyHistoryVisibilityFilter(
// Get the mapping from eventID -> eventVisibility
eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events))
visibilities, err := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID())
if err != nil {
return eventsFiltered, err
}
visibilities := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID())
for _, ev := range events {
evVis := visibilities[ev.EventID()]
evVis.membershipCurrent = membershipCurrent
@ -175,7 +172,7 @@ func visibilityForEvents(
rsAPI api.SyncRoomserverAPI,
events []*gomatrixserverlib.HeaderedEvent,
userID, roomID string,
) (map[string]eventVisibility, error) {
) map[string]eventVisibility {
eventIDs := make([]string, len(events))
for i := range events {
eventIDs[i] = events[i].EventID()
@ -185,6 +182,7 @@ func visibilityForEvents(
// get the membership events for all eventIDs
membershipResp := &api.QueryMembershipAtEventResponse{}
err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{
RoomID: roomID,
EventIDs: eventIDs,
@ -201,19 +199,20 @@ func visibilityForEvents(
membershipAtEvent: gomatrixserverlib.Leave, // default to leave, to not expose events by accident
visibility: event.Visibility,
}
membershipEvs, ok := membershipResp.Memberships[eventID]
if !ok {
ev, ok := membershipResp.Membership[eventID]
if !ok || ev == nil {
result[eventID] = vis
continue
}
for _, ev := range membershipEvs {
membership, err := ev.Membership()
if err != nil {
return result, err
result[eventID] = vis
continue
}
vis.membershipAtEvent = membership
}
result[eventID] = vis
}
return result, nil
return result
}

View file

@ -67,6 +67,8 @@ func Context(
errMsg = "unable to parse filter"
case *strconv.NumError:
errMsg = "unable to parse limit"
default:
errMsg = err.Error()
}
return util.JSONResponse{
Code: http.StatusBadRequest,
@ -167,7 +169,18 @@ func Context(
eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBeforeFiltered, gomatrixserverlib.FormatAll)
eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll)
newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache)
newState := state
if filter.LazyLoadMembers {
allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...)
allEvents = append(allEvents, &requestedEvent)
evs := gomatrixserverlib.HeaderedToClientEvents(allEvents, gomatrixserverlib.FormatAll)
newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache)
if err != nil {
logrus.WithError(err).Error("unable to load membership events")
return jsonerror.InternalServerError()
}
}
ev := gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll)
response := ContextRespsonse{
@ -244,41 +257,43 @@ func getStartEnd(ctx context.Context, snapshot storage.DatabaseTransaction, star
}
func applyLazyLoadMembers(
ctx context.Context,
device *userapi.Device,
filter *gomatrixserverlib.RoomEventFilter,
eventsAfter, eventsBefore []gomatrixserverlib.ClientEvent,
state []*gomatrixserverlib.HeaderedEvent,
snapshot storage.DatabaseTransaction,
roomID string,
events []gomatrixserverlib.ClientEvent,
lazyLoadCache caching.LazyLoadCache,
) []*gomatrixserverlib.HeaderedEvent {
if filter == nil || !filter.LazyLoadMembers {
return state
}
allEvents := append(eventsBefore, eventsAfter...)
x := make(map[string]struct{})
) ([]*gomatrixserverlib.HeaderedEvent, error) {
eventSenders := make(map[string]struct{})
// get members who actually send an event
for _, e := range allEvents {
for _, e := range events {
// Don't add membership events the client should already know about
if _, cached := lazyLoadCache.IsLazyLoadedUserCached(device, e.RoomID, e.Sender); cached {
continue
}
x[e.Sender] = struct{}{}
eventSenders[e.Sender] = struct{}{}
}
newState := []*gomatrixserverlib.HeaderedEvent{}
membershipEvents := []*gomatrixserverlib.HeaderedEvent{}
for _, event := range state {
if event.Type() != gomatrixserverlib.MRoomMember {
newState = append(newState, event)
} else {
// did the user send an event?
if _, ok := x[event.Sender()]; ok {
membershipEvents = append(membershipEvents, event)
lazyLoadCache.StoreLazyLoadedUser(device, event.RoomID(), event.Sender(), event.EventID())
wantUsers := make([]string, 0, len(eventSenders))
for userID := range eventSenders {
wantUsers = append(wantUsers, userID)
}
// Query missing membership events
filter := gomatrixserverlib.DefaultStateFilter()
filter.Senders = &wantUsers
filter.Types = &[]string{gomatrixserverlib.MRoomMember}
memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter)
if err != nil {
return nil, err
}
// cache the membership events
for _, membership := range memberships {
lazyLoadCache.StoreLazyLoadedUser(device, roomID, *membership.StateKey(), membership.EventID())
}
// Add the membershipEvents to the end of the list, to make Sytest happy
return append(newState, membershipEvents...)
return memberships, nil
}
func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) {

View file

@ -64,6 +64,7 @@ type messagesResp struct {
// OnIncomingMessagesRequest implements the /messages endpoint from the
// client-server API.
// See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages
// nolint:gocyclo
func OnIncomingMessagesRequest(
req *http.Request, db storage.Database, roomID string, device *userapi.Device,
rsAPI api.SyncRoomserverAPI,
@ -246,7 +247,14 @@ func OnIncomingMessagesRequest(
Start: start.String(),
End: end.String(),
}
res.applyLazyLoadMembers(req.Context(), snapshot, roomID, device, filter.LazyLoadMembers, lazyLoadCache)
if filter.LazyLoadMembers {
membershipEvents, err := applyLazyLoadMembers(req.Context(), device, snapshot, roomID, clientEvents, lazyLoadCache)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("failed to apply lazy loading")
return jsonerror.InternalServerError()
}
res.State = append(res.State, gomatrixserverlib.HeaderedToClientEvents(membershipEvents, gomatrixserverlib.FormatAll)...)
}
// If we didn't return any events, set the end to an empty string, so it will be omitted
// in the response JSON.
@ -265,40 +273,6 @@ func OnIncomingMessagesRequest(
}
}
// applyLazyLoadMembers loads membership events for users returned in Chunk, if the filter has
// LazyLoadMembers enabled.
func (m *messagesResp) applyLazyLoadMembers(
ctx context.Context,
db storage.DatabaseTransaction,
roomID string,
device *userapi.Device,
lazyLoad bool,
lazyLoadCache caching.LazyLoadCache,
) {
if !lazyLoad {
return
}
membershipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent)
for _, evt := range m.Chunk {
// Don't add membership events the client should already know about
if _, cached := lazyLoadCache.IsLazyLoadedUserCached(device, roomID, evt.Sender); cached {
continue
}
membership, err := db.GetStateEvent(ctx, roomID, gomatrixserverlib.MRoomMember, evt.Sender)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to get membership event for user")
continue
}
if membership != nil {
membershipToUser[evt.Sender] = membership
lazyLoadCache.StoreLazyLoadedUser(device, roomID, evt.Sender, membership.EventID())
}
}
for _, evt := range membershipToUser {
m.State = append(m.State, gomatrixserverlib.HeaderedToClientEvent(evt, gomatrixserverlib.FormatAll))
}
}
func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) {
req := api.QueryMembershipForUserRequest{
RoomID: roomID,

View file

@ -46,7 +46,7 @@ type DatabaseTransaction interface {
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error)
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
RecentEvents(ctx context.Context, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error)
GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error)
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)
InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error)

View file

@ -275,6 +275,15 @@ func (s *currentRoomStateStatements) SelectCurrentState(
) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter)
// We're going to query members later, so remove them from this request
if stateFilter.LazyLoadMembers && !stateFilter.IncludeRedundantMembers {
notTypes := &[]string{gomatrixserverlib.MRoomMember}
if stateFilter.NotTypes != nil {
*stateFilter.NotTypes = append(*stateFilter.NotTypes, gomatrixserverlib.MRoomMember)
} else {
stateFilter.NotTypes = notTypes
}
}
rows, err := stmt.QueryContext(ctx, roomID,
pq.StringArray(senders),
pq.StringArray(notSenders),

View file

@ -0,0 +1,29 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"context"
"database/sql"
"fmt"
)
func UpRenameOutputRoomEventsIndex(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `ALTER TABLE syncapi_output_room_events RENAME CONSTRAINT syncapi_event_id_idx TO syncapi_output_room_event_id_idx;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}

View file

@ -19,18 +19,17 @@ import (
"context"
"database/sql"
"encoding/json"
"fmt"
"sort"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const outputRoomEventsSchema = `
@ -44,7 +43,7 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
-- This isn't a problem for us since we just want to order by this field.
id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'),
-- The event ID for the event
event_id TEXT NOT NULL CONSTRAINT syncapi_event_id_idx UNIQUE,
event_id TEXT NOT NULL CONSTRAINT syncapi_output_room_event_id_idx UNIQUE,
-- The 'room_id' key for the event.
room_id TEXT NOT NULL,
-- The headered JSON for the event, containing potentially additional metadata such as
@ -79,13 +78,16 @@ CREATE INDEX IF NOT EXISTS syncapi_output_room_events_room_id_idx ON syncapi_out
CREATE INDEX IF NOT EXISTS syncapi_output_room_events_exclude_from_sync_idx ON syncapi_output_room_events (exclude_from_sync);
CREATE INDEX IF NOT EXISTS syncapi_output_room_events_add_state_ids_idx ON syncapi_output_room_events ((add_state_ids IS NOT NULL));
CREATE INDEX IF NOT EXISTS syncapi_output_room_events_remove_state_ids_idx ON syncapi_output_room_events ((remove_state_ids IS NOT NULL));
CREATE INDEX IF NOT EXISTS syncapi_output_room_events_recent_events_idx ON syncapi_output_room_events (room_id, exclude_from_sync, id, sender, type);
`
const insertEventSQL = "" +
"INSERT INTO syncapi_output_room_events (" +
"room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync, history_visibility" +
") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " +
"ON CONFLICT ON CONSTRAINT syncapi_event_id_idx DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $11) " +
"ON CONFLICT ON CONSTRAINT syncapi_output_room_event_id_idx DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $11) " +
"RETURNING id"
const selectEventsSQL = "" +
@ -109,14 +111,29 @@ const selectRecentEventsSQL = "" +
" AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" +
" ORDER BY id DESC LIMIT $8"
const selectRecentEventsForSyncSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +
" AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
" AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" +
" AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" +
" ORDER BY id DESC LIMIT $8"
// selectRecentEventsForSyncSQL contains an optimization to get the recent events for a list of rooms, using a LATERAL JOIN
// The sub select inside LATERAL () is executed for all room_ids it gets as a parameter $1
const selectRecentEventsForSyncSQL = `
WITH room_ids AS (
SELECT unnest($1::text[]) AS room_id
)
SELECT x.*
FROM room_ids,
LATERAL (
SELECT room_id, event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility
FROM syncapi_output_room_events recent_events
WHERE
recent_events.room_id = room_ids.room_id
AND recent_events.exclude_from_sync = FALSE
AND id > $2 AND id <= $3
AND ( $4::text[] IS NULL OR sender = ANY($4) )
AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )
AND ( $6::text[] IS NULL OR type LIKE ANY($6) )
AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )
ORDER BY recent_events.id DESC
LIMIT $8
) AS x
`
const selectEarlyEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" +
@ -207,12 +224,30 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
return nil, err
}
migrationName := "syncapi: rename dupe index (output_room_events)"
var cName string
err = db.QueryRowContext(context.Background(), "select constraint_name from information_schema.table_constraints where table_name = 'syncapi_output_room_events' AND constraint_name = 'syncapi_event_id_idx'").Scan(&cName)
switch err {
case sql.ErrNoRows: // migration was already executed, as the index was renamed
if err = sqlutil.InsertMigration(context.Background(), db, migrationName); err != nil {
return nil, fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err)
}
case nil:
default:
return nil, err
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(
sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
},
sqlutil.Migration{
Version: migrationName,
Up: deltas.UpRenameOutputRoomEventsIndex,
},
)
err = m.Up(context.Background())
if err != nil {
@ -398,9 +433,9 @@ func (s *outputRoomEventsStatements) InsertEvent(
// from sync.
func (s *outputRoomEventsStatements) SelectRecentEvents(
ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
roomIDs []string, ra types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
chronologicalOrder bool, onlySyncEvents bool,
) ([]types.StreamEvent, bool, error) {
) (map[string]types.RecentEvents, error) {
var stmt *sql.Stmt
if onlySyncEvents {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt)
@ -408,8 +443,9 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
}
senders, notSenders := getSendersRoomEventFilter(eventFilter)
rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(),
ctx, pq.StringArray(roomIDs), ra.Low(), ra.High(),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
@ -417,34 +453,80 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
eventFilter.Limit+1,
)
if err != nil {
return nil, false, err
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed")
events, err := rowsToStreamEvents(rows)
if err != nil {
return nil, false, err
result := make(map[string]types.RecentEvents)
for rows.Next() {
var (
roomID string
eventID string
streamPos types.StreamPosition
eventBytes []byte
excludeFromSync bool
sessionID *int64
txnID *string
transactionID *api.TransactionID
historyVisibility gomatrixserverlib.HistoryVisibility
)
if err := rows.Scan(&roomID, &eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID, &historyVisibility); err != nil {
return nil, err
}
if chronologicalOrder {
// The events need to be returned from oldest to latest, which isn't
// necessary the way the SQL query returns them, so a sort is necessary to
// ensure the events are in the right order in the slice.
sort.SliceStable(events, func(i int, j int) bool {
return events[i].StreamPosition < events[j].StreamPosition
})
// TODO: Handle redacted events
var ev gomatrixserverlib.HeaderedEvent
if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil {
return nil, err
}
// we queried for 1 more than the limit, so if we returned one more mark limited=true
limited := false
if len(events) > eventFilter.Limit {
limited = true
// re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last.
if chronologicalOrder {
events = events[1:]
} else {
events = events[:len(events)-1]
if sessionID != nil && txnID != nil {
transactionID = &api.TransactionID{
SessionID: *sessionID,
TransactionID: *txnID,
}
}
return events, limited, nil
r := result[roomID]
ev.Visibility = historyVisibility
r.Events = append(r.Events, types.StreamEvent{
HeaderedEvent: &ev,
StreamPosition: streamPos,
TransactionID: transactionID,
ExcludeFromSync: excludeFromSync,
})
result[roomID] = r
}
if chronologicalOrder {
for roomID, evs := range result {
// The events need to be returned from oldest to latest, which isn't
// necessary the way the SQL query returns them, so a sort is necessary to
// ensure the events are in the right order in the slice.
sort.SliceStable(evs.Events, func(i int, j int) bool {
return evs.Events[i].StreamPosition < evs.Events[j].StreamPosition
})
if len(evs.Events) > eventFilter.Limit {
evs.Limited = true
evs.Events = evs.Events[1:]
}
result[roomID] = evs
}
} else {
for roomID, evs := range result {
if len(evs.Events) > eventFilter.Limit {
evs.Limited = true
evs.Events = evs.Events[:len(evs.Events)-1]
}
result[roomID] = evs
}
}
return result, rows.Err()
}
// selectEarlyEvents returns the earliest events in the given room, starting

View file

@ -151,8 +151,8 @@ func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID
return summary, nil
}
func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) {
return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomIDs, r, eventFilter, chronologicalOrder, onlySyncEvents)
}
func (d *DatabaseTransaction) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) {
@ -370,20 +370,26 @@ func (d *DatabaseTransaction) GetStateDeltas(
}
// get all the state events ever (i.e. for all available rooms) between these two positions
stateNeededFiltered, eventMapFiltered, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
stateFiltered := state
// avoid hitting the database if the result would be the same as above
if !isStatefilterEmpty(stateFilter) {
var stateNeededFiltered map[string]map[string]bool
var eventMapFiltered map[string]types.StreamEvent
stateNeededFiltered, eventMapFiltered, err = d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
stateFiltered, err := d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered)
stateFiltered, err = d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
}
// find out which rooms this user is peeking, if any.
// We do this before joins so any peeks get overwritten
@ -701,6 +707,28 @@ func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context)
return types.StreamPosition(id), err
}
func isStatefilterEmpty(filter *gomatrixserverlib.StateFilter) bool {
if filter == nil {
return true
}
switch {
case filter.NotTypes != nil && len(*filter.NotTypes) > 0:
return false
case filter.Types != nil && len(*filter.Types) > 0:
return false
case filter.Senders != nil && len(*filter.Senders) > 0:
return false
case filter.NotSenders != nil && len(*filter.NotSenders) > 0:
return false
case filter.NotRooms != nil && len(*filter.NotRooms) > 0:
return false
case filter.ContainsURL != nil:
return false
default:
return true
}
}
func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (
events []types.StreamEvent, prevBatch, nextBatch string, err error,
) {

View file

@ -0,0 +1,72 @@
package shared
import (
"testing"
"github.com/matrix-org/gomatrixserverlib"
)
func Test_isStatefilterEmpty(t *testing.T) {
filterSet := []string{"a"}
boolValue := false
tests := []struct {
name string
filter *gomatrixserverlib.StateFilter
want bool
}{
{
name: "nil filter is empty",
filter: nil,
want: true,
},
{
name: "Empty filter is empty",
filter: &gomatrixserverlib.StateFilter{},
want: true,
},
{
name: "NotTypes is set",
filter: &gomatrixserverlib.StateFilter{
NotTypes: &filterSet,
},
},
{
name: "Types is set",
filter: &gomatrixserverlib.StateFilter{
Types: &filterSet,
},
},
{
name: "Senders is set",
filter: &gomatrixserverlib.StateFilter{
Senders: &filterSet,
},
},
{
name: "NotSenders is set",
filter: &gomatrixserverlib.StateFilter{
NotSenders: &filterSet,
},
},
{
name: "NotRooms is set",
filter: &gomatrixserverlib.StateFilter{
NotRooms: &filterSet,
},
},
{
name: "ContainsURL is set",
filter: &gomatrixserverlib.StateFilter{
ContainsURL: &boolValue,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isStatefilterEmpty(tt.filter); got != tt.want {
t.Errorf("isStatefilterEmpty() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -105,7 +105,6 @@ func (s *accountDataStatements) SelectAccountDataInRange(
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
[]string{}, nil, filter.Limit, FilterOrderAsc)
if err != nil {
return
}

View file

@ -267,6 +267,15 @@ func (s *currentRoomStateStatements) SelectCurrentState(
stateFilter *gomatrixserverlib.StateFilter,
excludeEventIDs []string,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
// We're going to query members later, so remove them from this request
if stateFilter.LazyLoadMembers && !stateFilter.IncludeRedundantMembers {
notTypes := &[]string{gomatrixserverlib.MRoomMember}
if stateFilter.NotTypes != nil {
*stateFilter.NotTypes = append(*stateFilter.NotTypes, gomatrixserverlib.MRoomMember)
} else {
stateFilter.NotTypes = notTypes
}
}
stmt, params, err := prepareWithFilters(
s.db, txn, selectCurrentStateSQL,
[]interface{}{

View file

@ -368,9 +368,9 @@ func (s *outputRoomEventsStatements) InsertEvent(
func (s *outputRoomEventsStatements) SelectRecentEvents(
ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
chronologicalOrder bool, onlySyncEvents bool,
) ([]types.StreamEvent, bool, error) {
) (map[string]types.RecentEvents, error) {
var query string
if onlySyncEvents {
query = selectRecentEventsForSyncSQL
@ -378,6 +378,8 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
query = selectRecentEventsSQL
}
result := make(map[string]types.RecentEvents, len(roomIDs))
for _, roomID := range roomIDs {
stmt, params, err := prepareWithFilters(
s.db, txn, query,
[]interface{}{
@ -388,18 +390,18 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc,
)
if err != nil {
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, false, err
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed")
events, err := rowsToStreamEvents(rows)
if err != nil {
return nil, false, err
return nil, err
}
if chronologicalOrder {
// The events need to be returned from oldest to latest, which isn't
@ -409,10 +411,10 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
return events[i].StreamPosition < events[j].StreamPosition
})
}
res := types.RecentEvents{}
// we queried for 1 more than the limit, so if we returned one more mark limited=true
limited := false
if len(events) > eventFilter.Limit {
limited = true
res.Limited = true
// re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last.
if chronologicalOrder {
events = events[1:]
@ -420,7 +422,11 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
events = events[:len(events)-1]
}
}
return events, limited, nil
res.Events = events
result[roomID] = res
}
return result, nil
}
func (s *outputRoomEventsStatements) SelectEarlyEvents(

View file

@ -156,12 +156,12 @@ func TestRecentEventsPDU(t *testing.T) {
tc := testCases[i]
t.Run(tc.Name, func(st *testing.T) {
var filter gomatrixserverlib.RoomEventFilter
var gotEvents []types.StreamEvent
var gotEvents map[string]types.RecentEvents
var limited bool
filter.Limit = tc.Limit
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
var err error
gotEvents, limited, err = snapshot.RecentEvents(ctx, r.ID, types.Range{
gotEvents, err = snapshot.RecentEvents(ctx, []string{r.ID}, types.Range{
From: tc.From,
To: tc.To,
}, &filter, !tc.ReverseOrder, true)
@ -169,15 +169,18 @@ func TestRecentEventsPDU(t *testing.T) {
st.Fatalf("failed to do sync: %s", err)
}
})
streamEvents := gotEvents[r.ID]
limited = streamEvents.Limited
if limited != tc.WantLimited {
st.Errorf("got limited=%v want %v", limited, tc.WantLimited)
}
if len(gotEvents) != len(tc.WantEvents) {
if len(streamEvents.Events) != len(tc.WantEvents) {
st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents))
}
for j := range gotEvents {
if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) {
st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON()))
for j := range streamEvents.Events {
if !reflect.DeepEqual(streamEvents.Events[j].JSON(), tc.WantEvents[j].JSON()) {
st.Errorf("event %d got %s want %s", j, string(streamEvents.Events[j].JSON()), string(tc.WantEvents[j].JSON()))
}
}
})
@ -923,3 +926,61 @@ func TestRoomSummary(t *testing.T) {
}
})
}
func TestRecentEvents(t *testing.T) {
alice := test.NewUser(t)
room1 := test.NewRoom(t, alice)
room2 := test.NewRoom(t, alice)
roomIDs := []string{room1.ID, room2.ID}
rooms := map[string]*test.Room{
room1.ID: room1,
room2.ID: room2,
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
filter := gomatrixserverlib.DefaultRoomEventFilter()
db, close, closeBase := MustCreateDatabase(t, dbType)
t.Cleanup(func() {
close()
closeBase()
})
MustWriteEvents(t, db, room1.Events())
MustWriteEvents(t, db, room2.Events())
transaction, err := db.NewDatabaseTransaction(ctx)
assert.NoError(t, err)
defer transaction.Rollback()
// get all recent events from 0 to 100 (we only created 5 events, so we should get 5 back)
roomEvs, err := transaction.RecentEvents(ctx, roomIDs, types.Range{From: 0, To: 100}, &filter, true, true)
assert.NoError(t, err)
assert.Equal(t, len(roomEvs), 2, "unexpected recent events response")
for _, recentEvents := range roomEvs {
assert.Equal(t, 5, len(recentEvents.Events), "unexpected recent events for room")
}
// update the filter to only return one event
filter.Limit = 1
roomEvs, err = transaction.RecentEvents(ctx, roomIDs, types.Range{From: 0, To: 100}, &filter, true, true)
assert.NoError(t, err)
assert.Equal(t, len(roomEvs), 2, "unexpected recent events response")
for roomID, recentEvents := range roomEvs {
origEvents := rooms[roomID].Events()
assert.Equal(t, true, recentEvents.Limited, "expected events to be limited")
assert.Equal(t, 1, len(recentEvents.Events), "unexpected recent events for room")
assert.Equal(t, origEvents[len(origEvents)-1].EventID(), recentEvents.Events[0].EventID())
}
// not chronologically ordered still returns the events in order (given ORDER BY id DESC)
roomEvs, err = transaction.RecentEvents(ctx, roomIDs, types.Range{From: 0, To: 100}, &filter, false, true)
assert.NoError(t, err)
assert.Equal(t, len(roomEvs), 2, "unexpected recent events response")
for roomID, recentEvents := range roomEvs {
origEvents := rooms[roomID].Events()
assert.Equal(t, true, recentEvents.Limited, "expected events to be limited")
assert.Equal(t, 1, len(recentEvents.Events), "unexpected recent events for room")
assert.Equal(t, origEvents[len(origEvents)-1].EventID(), recentEvents.Events[0].EventID())
}
})
}

View file

@ -13,6 +13,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
)
func newCurrentRoomStateTable(t *testing.T, dbType test.DBType) (tables.CurrentRoomState, *sql.DB, func()) {
@ -79,6 +80,9 @@ func TestCurrentRoomStateTable(t *testing.T) {
return fmt.Errorf("SelectEventsWithEventIDs\nexpected id %q not returned", id)
}
}
testCurrentState(t, ctx, txn, tab, room)
return nil
})
if err != nil {
@ -86,3 +90,39 @@ func TestCurrentRoomStateTable(t *testing.T) {
}
})
}
func testCurrentState(t *testing.T, ctx context.Context, txn *sql.Tx, tab tables.CurrentRoomState, room *test.Room) {
t.Run("test currentState", func(t *testing.T) {
// returns the complete state of the room with a default filter
filter := gomatrixserverlib.DefaultStateFilter()
evs, err := tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil)
if err != nil {
t.Fatal(err)
}
expectCount := 5
if gotCount := len(evs); gotCount != expectCount {
t.Fatalf("expected %d state events, got %d", expectCount, gotCount)
}
// When lazy loading, we expect no membership event, so only 4 events
filter.LazyLoadMembers = true
expectCount = 4
evs, err = tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil)
if err != nil {
t.Fatal(err)
}
if gotCount := len(evs); gotCount != expectCount {
t.Fatalf("expected %d state events, got %d", expectCount, gotCount)
}
// same as above, but with existing NotTypes defined
notTypes := []string{gomatrixserverlib.MRoomMember}
filter.NotTypes = &notTypes
evs, err = tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil)
if err != nil {
t.Fatal(err)
}
if gotCount := len(evs); gotCount != expectCount {
t.Fatalf("expected %d state events, got %d", expectCount, gotCount)
}
})
}

View file

@ -66,7 +66,7 @@ type Events interface {
// SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high.
// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync.
// Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`.
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error)
// SelectEarlyEvents returns the earliest events in the given room.
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error)

View file

@ -82,19 +82,24 @@ func (p *PDUStreamProvider) CompleteSync(
req.Log.WithError(err).Error("unable to update event filter with ignored users")
}
recentEvents, err := snapshot.RecentEvents(ctx, joinedRoomIDs, r, &eventFilter, true, true)
if err != nil {
return from
}
// Build up a /sync response. Add joined rooms.
for _, roomID := range joinedRoomIDs {
events := recentEvents[roomID]
// Invalidate the lazyLoadCache, otherwise we end up with missing displaynames/avatars
// TODO: This might be inefficient, when joined to many and/or large rooms.
for _, roomID := range joinedRoomIDs {
joinedUsers := p.notifier.JoinedUsers(roomID)
for _, sharedUser := range joinedUsers {
p.lazyLoadCache.InvalidateLazyLoadedUser(req.Device, roomID, sharedUser)
}
}
// Build up a /sync response. Add joined rooms.
for _, roomID := range joinedRoomIDs {
// get the join response for each room
jr, jerr := p.getJoinResponseForCompleteSync(
ctx, snapshot, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false,
ctx, snapshot, roomID, &stateFilter, req.WantFullState, req.Device, false,
events.Events, events.Limited,
)
if jerr != nil {
req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed")
@ -113,11 +118,25 @@ func (p *PDUStreamProvider) CompleteSync(
req.Log.WithError(err).Error("p.DB.PeeksInRange failed")
return from
}
if len(peeks) > 0 {
peekRooms := make([]string, 0, len(peeks))
for _, peek := range peeks {
if !peek.Deleted {
peekRooms = append(peekRooms, peek.RoomID)
}
}
recentEvents, err = snapshot.RecentEvents(ctx, peekRooms, r, &eventFilter, true, true)
if err != nil {
return from
}
for _, roomID := range peekRooms {
var jr *types.JoinResponse
events := recentEvents[roomID]
jr, err = p.getJoinResponseForCompleteSync(
ctx, snapshot, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true,
ctx, snapshot, roomID, &stateFilter, req.WantFullState, req.Device, true,
events.Events, events.Limited,
)
if err != nil {
req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
@ -126,7 +145,7 @@ func (p *PDUStreamProvider) CompleteSync(
}
continue
}
req.Response.Rooms.Peek[peek.RoomID] = jr
req.Response.Rooms.Peek[roomID] = jr
}
}
@ -227,7 +246,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
stateFilter *gomatrixserverlib.StateFilter,
req *types.SyncRequest,
) (types.StreamPosition, error) {
var err error
originalLimit := eventFilter.Limit
// If we're going backwards, grep at least X events, this is mostly to satisfy Sytest
if r.Backwards && originalLimit < recentEventBackwardsLimit {
@ -238,8 +257,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
}
}
recentStreamEvents, limited, err := snapshot.RecentEvents(
ctx, delta.RoomID, r,
dbEvents, err := snapshot.RecentEvents(
ctx, []string{delta.RoomID}, r,
eventFilter, true, true,
)
if err != nil {
@ -248,6 +267,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
}
return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err)
}
recentStreamEvents := dbEvents[delta.RoomID].Events
limited := dbEvents[delta.RoomID].Limited
recentEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering(
snapshot.StreamEventsToEvents(device, recentStreamEvents),
gomatrixserverlib.TopologicalOrderByPrevEvents,
@ -420,7 +443,7 @@ func applyHistoryVisibilityFilter(
"room_id": roomID,
"before": len(recentEvents),
"after": len(events),
}).Trace("Applied history visibility (sync)")
}).Debugf("Applied history visibility (sync)")
return events, nil
}
@ -428,25 +451,16 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
ctx context.Context,
snapshot storage.DatabaseTransaction,
roomID string,
r types.Range,
stateFilter *gomatrixserverlib.StateFilter,
eventFilter *gomatrixserverlib.RoomEventFilter,
wantFullState bool,
device *userapi.Device,
isPeek bool,
recentStreamEvents []types.StreamEvent,
limited bool,
) (jr *types.JoinResponse, err error) {
jr = types.NewJoinResponse()
// TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
recentStreamEvents, limited, err := snapshot.RecentEvents(
ctx, roomID, r, eventFilter, true, true,
)
if err != nil {
if err == sql.ErrNoRows {
return jr, nil
}
return
}
// Work our way through the timeline events and pick out the event IDs
// of any state events that appear in the timeline. We'll specifically

View file

@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/matrix-org/dendrite/syncapi/routing"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/tidwall/gjson"
@ -448,6 +449,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{
"access_token": bobDev.AccessToken,
"dir": "b",
"filter": `{"lazy_load_members":true}`, // check that lazy loading doesn't break history visibility
})))
if w.Code != 200 {
t.Logf("%s", w.Body.String())
@ -905,6 +907,177 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
}
}
func TestContext(t *testing.T) {
test.WithAllDatabases(t, testContext)
}
func testContext(t *testing.T, dbType test.DBType) {
tests := []struct {
name string
roomID string
eventID string
params map[string]string
wantError bool
wantStateLength int
wantBeforeLength int
wantAfterLength int
}{
{
name: "invalid filter",
params: map[string]string{
"filter": "{",
},
wantError: true,
},
{
name: "invalid limit",
params: map[string]string{
"limit": "abc",
},
wantError: true,
},
{
name: "high limit",
params: map[string]string{
"limit": "100000",
},
},
{
name: "fine limit",
params: map[string]string{
"limit": "10",
},
},
{
name: "last event without lazy loading",
wantStateLength: 5,
},
{
name: "last event with lazy loading",
params: map[string]string{
"filter": `{"lazy_load_members":true}`,
},
wantStateLength: 1,
},
{
name: "invalid room",
roomID: "!doesnotexist",
wantError: true,
},
{
name: "invalid eventID",
eventID: "$doesnotexist",
wantError: true,
},
{
name: "state is limited",
params: map[string]string{
"limit": "1",
},
wantStateLength: 1,
},
{
name: "events are not limited",
wantBeforeLength: 7,
},
{
name: "all events are limited",
params: map[string]string{
"limit": "1",
},
wantStateLength: 1,
wantBeforeLength: 1,
wantAfterLength: 1,
},
}
user := test.NewUser(t)
alice := userapi.Device{
ID: "ALICEID",
UserID: user.ID,
AccessToken: "ALICE_BEARER_TOKEN",
DisplayName: "Alice",
AccountType: userapi.AccountTypeUser,
}
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
defer baseClose()
// Use an actual roomserver for this
rsAPI := roomserver.NewInternalAPI(base)
rsAPI.SetFederationAPI(nil, nil)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, &syncKeyAPI{})
room := test.NewRoom(t, user)
room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world 1!"})
room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world 2!"})
thirdMsg := room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world3!"})
room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world4!"})
if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
}
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
syncUntil(t, base, alice.AccessToken, false, func(syncBody string) bool {
// wait for the last sent eventID to come down sync
path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, thirdMsg.EventID())
return gjson.Get(syncBody, path).Exists()
})
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
params := map[string]string{
"access_token": alice.AccessToken,
}
w := httptest.NewRecorder()
// test overrides
roomID := room.ID
if tc.roomID != "" {
roomID = tc.roomID
}
eventID := thirdMsg.EventID()
if tc.eventID != "" {
eventID = tc.eventID
}
requestPath := fmt.Sprintf("/_matrix/client/v3/rooms/%s/context/%s", roomID, eventID)
if tc.params != nil {
for k, v := range tc.params {
params[k] = v
}
}
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", requestPath, test.WithQueryParams(params)))
if tc.wantError && w.Code == 200 {
t.Fatalf("Expected an error, but got none")
}
t.Log(w.Body.String())
resp := routing.ContextRespsonse{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatal(err)
}
if tc.wantStateLength > 0 && tc.wantStateLength != len(resp.State) {
t.Fatalf("expected %d state events, got %d", tc.wantStateLength, len(resp.State))
}
if tc.wantBeforeLength > 0 && tc.wantBeforeLength != len(resp.EventsBefore) {
t.Fatalf("expected %d before events, got %d", tc.wantBeforeLength, len(resp.EventsBefore))
}
if tc.wantAfterLength > 0 && tc.wantAfterLength != len(resp.EventsAfter) {
t.Fatalf("expected %d after events, got %d", tc.wantAfterLength, len(resp.EventsAfter))
}
if !tc.wantError && resp.Event.EventID != eventID {
t.Fatalf("unexpected eventID %s, expected %s", resp.Event.EventID, eventID)
}
})
}
}
func syncUntil(t *testing.T,
base *base.BaseDendrite, accessToken string,
skip bool,

View file

@ -63,6 +63,11 @@ type StreamEvent struct {
ExcludeFromSync bool
}
type RecentEvents struct {
Limited bool
Events []StreamEvent
}
// Range represents a range between two stream positions.
type Range struct {
// From is the position the client has already received.