diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 36ba3a3e6..519aeff68 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -262,12 +262,8 @@ func (r *messagesReq) retrieveEvents() ( clientEvents []gomatrixserverlib.ClientEvent, start, end types.TopologyToken, err error, ) { - eventFilter := r.filter - // Retrieve the events from the local database. - streamEvents, err := r.db.GetEventsInTopologicalRange( - r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering, - ) + streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering) if err != nil { err = fmt.Errorf("GetEventsInRange: %w", err) return diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index cf3fd5532..14cb08a52 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -105,7 +105,7 @@ type Database interface { // Returns an error if there was a problem communicating with the database. DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last. - GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) + GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error) // EventPositionInTopology returns the depth and stream position of the given event. EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 269cd4494..17e2feab6 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -81,6 +81,15 @@ const insertEventSQL = "" + const selectEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" +const selectEventsWithFilterSQL = "" + + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" + + " AND ( $2::text[] IS NULL OR sender = ANY($2) )" + + " AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" + + " AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" + + " AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" + + " AND ( $6::bool IS NULL OR contains_url = $6 )" + + " LIMIT $7" + const selectRecentEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + @@ -153,6 +162,7 @@ const selectContextAfterEventSQL = "" + type outputRoomEventsStatements struct { insertEventStmt *sql.Stmt selectEventsStmt *sql.Stmt + selectEventsWitFilterStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt selectRecentEventsForSyncStmt *sql.Stmt @@ -174,6 +184,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { return s, sqlutil.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventsStmt, selectEventsSQL}, + {&s.selectEventsWitFilterStmt, selectEventsWithFilterSQL}, {&s.selectMaxEventIDStmt, selectMaxEventIDSQL}, {&s.selectRecentEventsStmt, selectRecentEventsSQL}, {&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL}, @@ -310,7 +321,7 @@ func (s *outputRoomEventsStatements) InsertEvent( // Parse content as JSON and search for an "url" key containsURL := false var content map[string]interface{} - if json.Unmarshal(event.Content(), &content) != nil { + if json.Unmarshal(event.Content(), &content) == nil { // Set containsURL to true if url is present _, containsURL = content["url"] } @@ -429,10 +440,29 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. func (s *outputRoomEventsStatements) SelectEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool, + ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool, ) ([]types.StreamEvent, error) { - stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) - rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) + var ( + stmt *sql.Stmt + rows *sql.Rows + err error + ) + if filter == nil { + stmt = sqlutil.TxStmt(txn, s.selectEventsStmt) + rows, err = stmt.QueryContext(ctx, pq.StringArray(eventIDs)) + } else { + senders, notSenders := getSendersRoomEventFilter(filter) + stmt = sqlutil.TxStmt(txn, s.selectEventsWitFilterStmt) + rows, err = stmt.QueryContext(ctx, + pq.StringArray(eventIDs), + pq.StringArray(senders), + pq.StringArray(notSenders), + pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)), + pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)), + filter.ContainsURL, + filter.Limit, + ) + } if err != nil { return nil, err } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 14db5795c..91eba44e1 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre // Returns an error if there was a problem talking with the database. // Does not include any transaction IDs in the returned events. func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { - streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, false) + streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false) if err != nil { return nil, err } @@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e // Check if we have all of the event's previous events. If an event is // missing, add it to the room's backward extremities. - prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), false) + prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false) if err != nil { return err } @@ -429,7 +429,8 @@ func (d *Database) updateRoomState( func (d *Database) GetEventsInTopologicalRange( ctx context.Context, from, to *types.TopologyToken, - roomID string, limit int, + roomID string, + filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool, ) (events []types.StreamEvent, err error) { var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition @@ -450,14 +451,14 @@ func (d *Database) GetEventsInTopologicalRange( // Select the event IDs from the defined range. var eIDs []string eIDs, err = d.Topology.SelectEventIDsInRange( - ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, limit, !backwardOrdering, + ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering, ) if err != nil { return } // Retrieve the events' contents using their IDs. - events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, true) + events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, filter, true) return } @@ -619,7 +620,7 @@ func (d *Database) fetchMissingStateEvents( ) ([]types.StreamEvent, error) { // Fetch from the events table first so we pick up the stream ID for the // event. - events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, false) + events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false) if err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index b0aeb70f2..71a098177 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -104,8 +104,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( }, filter.Senders, filter.NotSenders, filter.Types, filter.NotTypes, - []string{}, filter.Limit, FilterOrderAsc, - ) + []string{}, nil, filter.Limit, FilterOrderAsc) rows, err := stmt.QueryContext(ctx, params...) if err != nil { diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 464f32e04..ccda005c1 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -220,7 +220,7 @@ func (s *currentRoomStateStatements) SelectCurrentState( }, stateFilter.Senders, stateFilter.NotSenders, stateFilter.Types, stateFilter.NotTypes, - excludeEventIDs, stateFilter.Limit, FilterOrderNone, + excludeEventIDs, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderNone, ) if err != nil { return nil, fmt.Errorf("s.prepareWithFilters: %w", err) diff --git a/syncapi/storage/sqlite3/filtering.go b/syncapi/storage/sqlite3/filtering.go index 54b12ddf8..05edb7b8c 100644 --- a/syncapi/storage/sqlite3/filtering.go +++ b/syncapi/storage/sqlite3/filtering.go @@ -26,7 +26,7 @@ const ( func prepareWithFilters( db *sql.DB, txn *sql.Tx, query string, params []interface{}, senders, notsenders, types, nottypes *[]string, excludeEventIDs []string, - limit int, order FilterOrder, + containsURL *bool, limit int, order FilterOrder, ) (*sql.Stmt, []interface{}, error) { offset := len(params) if senders != nil { @@ -69,6 +69,9 @@ func prepareWithFilters( query += ` AND type NOT = ""` } } + if containsURL != nil { + query += fmt.Sprintf(" AND contains_url = %v", *containsURL) + } if count := len(excludeEventIDs); count > 0 { query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset) for _, v := range excludeEventIDs { diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 9da9d776e..188f7582b 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -168,7 +168,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( s.db, txn, stmtSQL, inputParams, stateFilter.Senders, stateFilter.NotSenders, stateFilter.Types, stateFilter.NotTypes, - nil, stateFilter.Limit, FilterOrderAsc, + nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc, ) if err != nil { return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) @@ -277,7 +277,7 @@ func (s *outputRoomEventsStatements) InsertEvent( // Parse content as JSON and search for an "url" key containsURL := false var content map[string]interface{} - if json.Unmarshal(event.Content(), &content) != nil { + if json.Unmarshal(event.Content(), &content) == nil { // Set containsURL to true if url is present _, containsURL = content["url"] } @@ -345,7 +345,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( }, eventFilter.Senders, eventFilter.NotSenders, eventFilter.Types, eventFilter.NotTypes, - nil, eventFilter.Limit+1, FilterOrderDesc, + nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc, ) if err != nil { return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) @@ -393,7 +393,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( }, eventFilter.Senders, eventFilter.NotSenders, eventFilter.Types, eventFilter.NotTypes, - nil, eventFilter.Limit, FilterOrderAsc, + nil, eventFilter.ContainsURL, eventFilter.Limit, FilterOrderAsc, ) if err != nil { return nil, fmt.Errorf("s.prepareWithFilters: %w", err) @@ -419,20 +419,27 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. func (s *outputRoomEventsStatements) SelectEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool, + ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool, ) ([]types.StreamEvent, error) { iEventIDs := make([]interface{}, len(eventIDs)) for i := range eventIDs { iEventIDs[i] = eventIDs[i] } selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1) - var rows *sql.Rows - var err error - if txn != nil { - rows, err = txn.QueryContext(ctx, selectSQL, iEventIDs...) - } else { - rows, err = s.db.QueryContext(ctx, selectSQL, iEventIDs...) + + if filter == nil { + filter = &gomatrixserverlib.RoomEventFilter{Limit: 20} } + stmt, params, err := prepareWithFilters( + s.db, txn, selectSQL, iEventIDs, + filter.Senders, filter.NotSenders, + filter.Types, filter.NotTypes, + nil, filter.ContainsURL, filter.Limit, FilterOrderAsc, + ) + if err != nil { + return nil, err + } + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, err } @@ -527,7 +534,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent( }, filter.Senders, filter.NotSenders, filter.Types, filter.NotTypes, - nil, filter.Limit, FilterOrderDesc, + nil, filter.ContainsURL, filter.Limit, FilterOrderDesc, ) rows, err := stmt.QueryContext(ctx, params...) @@ -563,7 +570,7 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent( }, filter.Senders, filter.NotSenders, filter.Types, filter.NotTypes, - nil, filter.Limit, FilterOrderAsc, + nil, filter.ContainsURL, filter.Limit, FilterOrderAsc, ) rows, err := stmt.QueryContext(ctx, params...) diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 4e1634ece..15bb769a2 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -180,7 +180,8 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { to := types.TopologyToken{} // backpaginate 5 messages starting at the latest position. - paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, 5, true) + filter := &gomatrixserverlib.RoomEventFilter{Limit: 5} + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true) if err != nil { t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index a7df70248..993e2022b 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -59,7 +59,7 @@ type Events interface { SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) // SelectEarlyEvents returns the earliest events in the given room. SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) - SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool) ([]types.StreamEvent, error) + SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) diff --git a/syncapi/storage/tables/output_room_events_test.go b/syncapi/storage/tables/output_room_events_test.go index 7a81ffcd2..a143e5ecd 100644 --- a/syncapi/storage/tables/output_room_events_test.go +++ b/syncapi/storage/tables/output_room_events_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" ) func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) { @@ -61,7 +62,7 @@ func TestOutputRoomEventsTable(t *testing.T) { wantEventIDs := []string{ events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(), } - gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, true) + gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, nil, true) if err != nil { return fmt.Errorf("failed to SelectEvents: %s", err) } @@ -73,6 +74,28 @@ func TestOutputRoomEventsTable(t *testing.T) { return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs) } + // Test that contains_url is correctly populated + urlEv := room.CreateEvent(t, alice, "m.text", map[string]interface{}{ + "body": "test.txt", + "url": "mxc://test.txt", + }) + if _, err = tab.InsertEvent(ctx, txn, urlEv, nil, nil, nil, false); err != nil { + return fmt.Errorf("failed to InsertEvent: %s", err) + } + wantEventID := []string{urlEv.EventID()} + t := true + gotEvents, err = tab.SelectEvents(ctx, txn, wantEventID, &gomatrixserverlib.RoomEventFilter{Limit: 1, ContainsURL: &t}, true) + if err != nil { + return fmt.Errorf("failed to SelectEvents: %s", err) + } + gotEventIDs = make([]string, len(gotEvents)) + for i := range gotEvents { + gotEventIDs[i] = gotEvents[i].EventID() + } + if !reflect.DeepEqual(gotEventIDs, wantEventID) { + return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventID) + } + return nil }) if err != nil { diff --git a/sytest-whitelist b/sytest-whitelist index a7aea05ed..f63b96f52 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -698,4 +698,5 @@ Ignore user in existing room Ignore invite in full sync Ignore invite in incremental sync A filtered timeline reaches its limit -A change to displayname should not result in a full state sync \ No newline at end of file +A change to displayname should not result in a full state sync +Can fetch images in room \ No newline at end of file