Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/lazyloading

This commit is contained in:
Till Faelligen 2022-04-13 14:02:13 +02:00
commit 849cfeceac
13 changed files with 153 additions and 53 deletions

View file

@ -36,7 +36,7 @@ import (
type Notifier struct { type Notifier struct {
lock *sync.RWMutex lock *sync.RWMutex
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine // A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
roomIDToJoinedUsers map[string]userIDSet roomIDToJoinedUsers map[string]*userIDSet
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine // A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
roomIDToPeekingDevices map[string]peekingDeviceSet roomIDToPeekingDevices map[string]peekingDeviceSet
// The latest sync position // The latest sync position
@ -54,7 +54,7 @@ type Notifier struct {
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
func NewNotifier() *Notifier { func NewNotifier() *Notifier {
return &Notifier{ return &Notifier{
roomIDToJoinedUsers: make(map[string]userIDSet), roomIDToJoinedUsers: make(map[string]*userIDSet),
roomIDToPeekingDevices: make(map[string]peekingDeviceSet), roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
userDeviceStreams: make(map[string]map[string]*UserDeviceStream), userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
lock: &sync.RWMutex{}, lock: &sync.RWMutex{},
@ -262,7 +262,7 @@ func (n *Notifier) SharedUsers(userID string) []string {
func (n *Notifier) _sharedUsers(userID string) []string { func (n *Notifier) _sharedUsers(userID string) []string {
n._sharedUserMap[userID] = struct{}{} n._sharedUserMap[userID] = struct{}{}
for roomID, users := range n.roomIDToJoinedUsers { for roomID, users := range n.roomIDToJoinedUsers {
if _, ok := users[userID]; !ok { if ok := users.isIn(userID); !ok {
continue continue
} }
for _, userID := range n._joinedUsers(roomID) { for _, userID := range n._joinedUsers(roomID) {
@ -282,8 +282,11 @@ func (n *Notifier) IsSharedUser(userA, userB string) bool {
defer n.lock.RUnlock() defer n.lock.RUnlock()
var okA, okB bool var okA, okB bool
for _, users := range n.roomIDToJoinedUsers { for _, users := range n.roomIDToJoinedUsers {
_, okA = users[userA] okA = users.isIn(userA)
_, okB = users[userB] if !okA {
continue
}
okB = users.isIn(userB)
if okA && okB { if okA && okB {
return true return true
} }
@ -345,11 +348,12 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
// This is just the bulk form of addJoinedUser // This is just the bulk form of addJoinedUser
for roomID, userIDs := range roomIDToUserIDs { for roomID, userIDs := range roomIDToUserIDs {
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
n.roomIDToJoinedUsers[roomID] = make(userIDSet, len(userIDs)) n.roomIDToJoinedUsers[roomID] = newUserIDSet(len(userIDs))
} }
for _, userID := range userIDs { for _, userID := range userIDs {
n.roomIDToJoinedUsers[roomID].add(userID) n.roomIDToJoinedUsers[roomID].add(userID)
} }
n.roomIDToJoinedUsers[roomID].precompute()
} }
} }
@ -440,16 +444,18 @@ func (n *Notifier) _fetchUserStreams(userID string) []*UserDeviceStream {
func (n *Notifier) _addJoinedUser(roomID, userID string) { func (n *Notifier) _addJoinedUser(roomID, userID string) {
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
n.roomIDToJoinedUsers[roomID] = make(userIDSet) n.roomIDToJoinedUsers[roomID] = newUserIDSet(8)
} }
n.roomIDToJoinedUsers[roomID].add(userID) n.roomIDToJoinedUsers[roomID].add(userID)
n.roomIDToJoinedUsers[roomID].precompute()
} }
func (n *Notifier) _removeJoinedUser(roomID, userID string) { func (n *Notifier) _removeJoinedUser(roomID, userID string) {
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
n.roomIDToJoinedUsers[roomID] = make(userIDSet) n.roomIDToJoinedUsers[roomID] = newUserIDSet(8)
} }
n.roomIDToJoinedUsers[roomID].remove(userID) n.roomIDToJoinedUsers[roomID].remove(userID)
n.roomIDToJoinedUsers[roomID].precompute()
} }
func (n *Notifier) JoinedUsers(roomID string) (userIDs []string) { func (n *Notifier) JoinedUsers(roomID string) (userIDs []string) {
@ -521,19 +527,52 @@ func (n *Notifier) _removeEmptyUserStreams() {
} }
// A string set, mainly existing for improving clarity of structs in this file. // A string set, mainly existing for improving clarity of structs in this file.
type userIDSet map[string]struct{} type userIDSet struct {
sync.Mutex
func (s userIDSet) add(str string) { set map[string]struct{}
s[str] = struct{}{} precomputed []string
} }
func (s userIDSet) remove(str string) { func newUserIDSet(cap int) *userIDSet {
delete(s, str) return &userIDSet{
set: make(map[string]struct{}, cap),
precomputed: nil,
}
} }
func (s userIDSet) values() (vals []string) { func (s *userIDSet) add(str string) {
vals = make([]string, 0, len(s)) s.Lock()
for str := range s { defer s.Unlock()
s.set[str] = struct{}{}
s.precomputed = s.precomputed[:0] // invalidate cache
}
func (s *userIDSet) remove(str string) {
s.Lock()
defer s.Unlock()
delete(s.set, str)
s.precomputed = s.precomputed[:0] // invalidate cache
}
func (s *userIDSet) precompute() {
s.Lock()
defer s.Unlock()
s.precomputed = s.values()
}
func (s *userIDSet) isIn(str string) bool {
s.Lock()
defer s.Unlock()
_, ok := s.set[str]
return ok
}
func (s *userIDSet) values() (vals []string) {
if len(s.precomputed) > 0 {
return s.precomputed // only return if not invalidated
}
vals = make([]string, 0, len(s.set))
for str := range s.set {
vals = append(vals, str) vals = append(vals, str)
} }
return return

View file

@ -262,12 +262,8 @@ func (r *messagesReq) retrieveEvents() (
clientEvents []gomatrixserverlib.ClientEvent, start, clientEvents []gomatrixserverlib.ClientEvent, start,
end types.TopologyToken, err error, end types.TopologyToken, err error,
) { ) {
eventFilter := r.filter
// Retrieve the events from the local database. // Retrieve the events from the local database.
streamEvents, err := r.db.GetEventsInTopologicalRange( streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
)
if err != nil { if err != nil {
err = fmt.Errorf("GetEventsInRange: %w", err) err = fmt.Errorf("GetEventsInRange: %w", err)
return return

View file

@ -105,7 +105,7 @@ type Database interface {
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last. // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event. // EventPositionInTopology returns the depth and stream position of the given event.
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.

View file

@ -81,6 +81,15 @@ const insertEventSQL = "" +
const selectEventsSQL = "" + const selectEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)"
const selectEventsWithFilterSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" +
" AND ( $2::text[] IS NULL OR sender = ANY($2) )" +
" AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" +
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" +
" AND ( $6::bool IS NULL OR contains_url = $6 )" +
" LIMIT $7"
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
@ -153,6 +162,7 @@ const selectContextAfterEventSQL = "" +
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt selectEventsStmt *sql.Stmt
selectEventsWitFilterStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt
selectRecentEventsForSyncStmt *sql.Stmt selectRecentEventsForSyncStmt *sql.Stmt
@ -174,6 +184,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL}, {&s.insertEventStmt, insertEventSQL},
{&s.selectEventsStmt, selectEventsSQL}, {&s.selectEventsStmt, selectEventsSQL},
{&s.selectEventsWitFilterStmt, selectEventsWithFilterSQL},
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL}, {&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
{&s.selectRecentEventsStmt, selectRecentEventsSQL}, {&s.selectRecentEventsStmt, selectRecentEventsSQL},
{&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL}, {&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL},
@ -310,7 +321,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
// Parse content as JSON and search for an "url" key // Parse content as JSON and search for an "url" key
containsURL := false containsURL := false
var content map[string]interface{} var content map[string]interface{}
if json.Unmarshal(event.Content(), &content) != nil { if json.Unmarshal(event.Content(), &content) == nil {
// Set containsURL to true if url is present // Set containsURL to true if url is present
_, containsURL = content["url"] _, containsURL = content["url"]
} }
@ -429,10 +440,29 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
// selectEvents returns the events for the given event IDs. If an event is // selectEvents returns the events for the given event IDs. If an event is
// missing from the database, it will be omitted. // missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) SelectEvents( func (s *outputRoomEventsStatements) SelectEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool, ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) var (
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) stmt *sql.Stmt
rows *sql.Rows
err error
)
if filter == nil {
stmt = sqlutil.TxStmt(txn, s.selectEventsStmt)
rows, err = stmt.QueryContext(ctx, pq.StringArray(eventIDs))
} else {
senders, notSenders := getSendersRoomEventFilter(filter)
stmt = sqlutil.TxStmt(txn, s.selectEventsWitFilterStmt)
rows, err = stmt.QueryContext(ctx,
pq.StringArray(eventIDs),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
filter.ContainsURL,
filter.Limit,
)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre
// Returns an error if there was a problem talking with the database. // Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events. // Does not include any transaction IDs in the returned events.
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, false) streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
// Check if we have all of the event's previous events. If an event is // Check if we have all of the event's previous events. If an event is
// missing, add it to the room's backward extremities. // missing, add it to the room's backward extremities.
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), false) prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false)
if err != nil { if err != nil {
return err return err
} }
@ -429,7 +429,8 @@ func (d *Database) updateRoomState(
func (d *Database) GetEventsInTopologicalRange( func (d *Database) GetEventsInTopologicalRange(
ctx context.Context, ctx context.Context,
from, to *types.TopologyToken, from, to *types.TopologyToken,
roomID string, limit int, roomID string,
filter *gomatrixserverlib.RoomEventFilter,
backwardOrdering bool, backwardOrdering bool,
) (events []types.StreamEvent, err error) { ) (events []types.StreamEvent, err error) {
var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
@ -450,14 +451,14 @@ func (d *Database) GetEventsInTopologicalRange(
// Select the event IDs from the defined range. // Select the event IDs from the defined range.
var eIDs []string var eIDs []string
eIDs, err = d.Topology.SelectEventIDsInRange( eIDs, err = d.Topology.SelectEventIDsInRange(
ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, limit, !backwardOrdering, ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
) )
if err != nil { if err != nil {
return return
} }
// Retrieve the events' contents using their IDs. // Retrieve the events' contents using their IDs.
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, true) events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, filter, true)
return return
} }
@ -619,7 +620,7 @@ func (d *Database) fetchMissingStateEvents(
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
// Fetch from the events table first so we pick up the stream ID for the // Fetch from the events table first so we pick up the stream ID for the
// event. // event.
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, false) events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -104,8 +104,7 @@ func (s *accountDataStatements) SelectAccountDataInRange(
}, },
filter.Senders, filter.NotSenders, filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes, filter.Types, filter.NotTypes,
[]string{}, filter.Limit, FilterOrderAsc, []string{}, nil, filter.Limit, FilterOrderAsc)
)
rows, err := stmt.QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {

View file

@ -220,7 +220,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
}, },
stateFilter.Senders, stateFilter.NotSenders, stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes, stateFilter.Types, stateFilter.NotTypes,
excludeEventIDs, stateFilter.Limit, FilterOrderNone, excludeEventIDs, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderNone,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, fmt.Errorf("s.prepareWithFilters: %w", err)

View file

@ -26,7 +26,7 @@ const (
func prepareWithFilters( func prepareWithFilters(
db *sql.DB, txn *sql.Tx, query string, params []interface{}, db *sql.DB, txn *sql.Tx, query string, params []interface{},
senders, notsenders, types, nottypes *[]string, excludeEventIDs []string, senders, notsenders, types, nottypes *[]string, excludeEventIDs []string,
limit int, order FilterOrder, containsURL *bool, limit int, order FilterOrder,
) (*sql.Stmt, []interface{}, error) { ) (*sql.Stmt, []interface{}, error) {
offset := len(params) offset := len(params)
if senders != nil { if senders != nil {
@ -69,6 +69,9 @@ func prepareWithFilters(
query += ` AND type NOT = ""` query += ` AND type NOT = ""`
} }
} }
if containsURL != nil {
query += fmt.Sprintf(" AND contains_url = %v", *containsURL)
}
if count := len(excludeEventIDs); count > 0 { if count := len(excludeEventIDs); count > 0 {
query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset) query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range excludeEventIDs { for _, v := range excludeEventIDs {

View file

@ -168,7 +168,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
s.db, txn, stmtSQL, inputParams, s.db, txn, stmtSQL, inputParams,
stateFilter.Senders, stateFilter.NotSenders, stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes, stateFilter.Types, stateFilter.NotTypes,
nil, stateFilter.Limit, FilterOrderAsc, nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc,
) )
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -277,7 +277,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
// Parse content as JSON and search for an "url" key // Parse content as JSON and search for an "url" key
containsURL := false containsURL := false
var content map[string]interface{} var content map[string]interface{}
if json.Unmarshal(event.Content(), &content) != nil { if json.Unmarshal(event.Content(), &content) == nil {
// Set containsURL to true if url is present // Set containsURL to true if url is present
_, containsURL = content["url"] _, containsURL = content["url"]
} }
@ -345,7 +345,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
}, },
eventFilter.Senders, eventFilter.NotSenders, eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes, eventFilter.Types, eventFilter.NotTypes,
nil, eventFilter.Limit+1, FilterOrderDesc, nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc,
) )
if err != nil { if err != nil {
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -393,7 +393,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
}, },
eventFilter.Senders, eventFilter.NotSenders, eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes, eventFilter.Types, eventFilter.NotTypes,
nil, eventFilter.Limit, FilterOrderAsc, nil, eventFilter.ContainsURL, eventFilter.Limit, FilterOrderAsc,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -419,20 +419,27 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
// selectEvents returns the events for the given event IDs. If an event is // selectEvents returns the events for the given event IDs. If an event is
// missing from the database, it will be omitted. // missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) SelectEvents( func (s *outputRoomEventsStatements) SelectEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool, ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
iEventIDs := make([]interface{}, len(eventIDs)) iEventIDs := make([]interface{}, len(eventIDs))
for i := range eventIDs { for i := range eventIDs {
iEventIDs[i] = eventIDs[i] iEventIDs[i] = eventIDs[i]
} }
selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1) selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
var rows *sql.Rows
var err error if filter == nil {
if txn != nil { filter = &gomatrixserverlib.RoomEventFilter{Limit: 20}
rows, err = txn.QueryContext(ctx, selectSQL, iEventIDs...)
} else {
rows, err = s.db.QueryContext(ctx, selectSQL, iEventIDs...)
} }
stmt, params, err := prepareWithFilters(
s.db, txn, selectSQL, iEventIDs,
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
)
if err != nil {
return nil, err
}
rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -527,7 +534,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
}, },
filter.Senders, filter.NotSenders, filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes, filter.Types, filter.NotTypes,
nil, filter.Limit, FilterOrderDesc, nil, filter.ContainsURL, filter.Limit, FilterOrderDesc,
) )
rows, err := stmt.QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)
@ -563,7 +570,7 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent(
}, },
filter.Senders, filter.NotSenders, filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes, filter.Types, filter.NotTypes,
nil, filter.Limit, FilterOrderAsc, nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
) )
rows, err := stmt.QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)

View file

@ -180,7 +180,8 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
to := types.TopologyToken{} to := types.TopologyToken{}
// backpaginate 5 messages starting at the latest position. // backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, 5, true) filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
if err != nil { if err != nil {
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
} }

View file

@ -59,7 +59,7 @@ type Events interface {
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
// SelectEarlyEvents returns the earliest events in the given room. // SelectEarlyEvents returns the earliest events in the given room.
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error)
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)

View file

@ -13,6 +13,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
) )
func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) { func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) {
@ -61,7 +62,7 @@ func TestOutputRoomEventsTable(t *testing.T) {
wantEventIDs := []string{ wantEventIDs := []string{
events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(), events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(),
} }
gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, true) gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, nil, true)
if err != nil { if err != nil {
return fmt.Errorf("failed to SelectEvents: %s", err) return fmt.Errorf("failed to SelectEvents: %s", err)
} }
@ -73,6 +74,28 @@ func TestOutputRoomEventsTable(t *testing.T) {
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs) return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs)
} }
// Test that contains_url is correctly populated
urlEv := room.CreateEvent(t, alice, "m.text", map[string]interface{}{
"body": "test.txt",
"url": "mxc://test.txt",
})
if _, err = tab.InsertEvent(ctx, txn, urlEv, nil, nil, nil, false); err != nil {
return fmt.Errorf("failed to InsertEvent: %s", err)
}
wantEventID := []string{urlEv.EventID()}
t := true
gotEvents, err = tab.SelectEvents(ctx, txn, wantEventID, &gomatrixserverlib.RoomEventFilter{Limit: 1, ContainsURL: &t}, true)
if err != nil {
return fmt.Errorf("failed to SelectEvents: %s", err)
}
gotEventIDs = make([]string, len(gotEvents))
for i := range gotEvents {
gotEventIDs[i] = gotEvents[i].EventID()
}
if !reflect.DeepEqual(gotEventIDs, wantEventID) {
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventID)
}
return nil return nil
}) })
if err != nil { if err != nil {

View file

@ -699,6 +699,7 @@ Ignore invite in full sync
Ignore invite in incremental sync Ignore invite in incremental sync
A filtered timeline reaches its limit A filtered timeline reaches its limit
A change to displayname should not result in a full state sync A change to displayname should not result in a full state sync
Can fetch images in room
The only membership state included in an initial sync is for all the senders in the timeline The only membership state included in an initial sync is for all the senders in the timeline
The only membership state included in an incremental sync is for senders in the timeline The only membership state included in an incremental sync is for senders in the timeline
Old members are included in gappy incr LL sync if they start speaking Old members are included in gappy incr LL sync if they start speaking