diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index ac6590575..f6f0e0565 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "strings" "github.com/matrix-org/dendrite/internal" @@ -66,13 +67,8 @@ const selectRoomIDsWithMembershipSQL = "" + "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" const selectCurrentStateSQL = "" + - "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + - " AND ( $2 IS NULL OR sender IN ($2) )" + - " AND ( $3 IS NULL OR NOT(sender IN ($3)) )" + - " AND ( $4 IS NULL OR type IN ($4) )" + - " AND ( $5 IS NULL OR NOT(type IN ($5)) )" + - " AND ( $6 IS NULL OR contains_url = $6 )" + - " LIMIT $7" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + // WHEN, ORDER BY and LIMIT will be added by prepareWithFilter const selectJoinedUsersSQL = "" + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" @@ -95,7 +91,6 @@ type currentRoomStateStatements struct { deleteRoomStateByEventIDStmt *sql.Stmt DeleteRoomStateForRoomStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt - selectCurrentStateStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt selectStateEventStmt *sql.Stmt } @@ -121,9 +116,6 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (t if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { return nil, err } - if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil { - return nil, err - } if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { return nil, err } @@ -185,17 +177,20 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( // CurrentState returns all the current state events for the given room. func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, - stateFilterPart *gomatrixserverlib.StateFilter, + stateFilter *gomatrixserverlib.StateFilter, ) ([]*gomatrixserverlib.HeaderedEvent, error) { - stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) - rows, err := stmt.QueryContext(ctx, roomID, - nil, // FIXME: pq.StringArray(stateFilterPart.Senders), - nil, // FIXME: pq.StringArray(stateFilterPart.NotSenders), - nil, // FIXME: pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)), - nil, // FIXME: pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)), - stateFilterPart.ContainsURL, - stateFilterPart.Limit, + stmt, params, err := prepareWithFilters( + s.db, selectCurrentStateSQL, + []interface{}{}, + stateFilter.Senders, stateFilter.NotSenders, + stateFilter.Types, stateFilter.NotTypes, + stateFilter.Limit, "", ) + if err != nil { + return nil, fmt.Errorf("s.prepareWithFilters: %w", err) + } + + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/filtering.go b/syncapi/storage/sqlite3/filtering.go new file mode 100644 index 000000000..90def5f01 --- /dev/null +++ b/syncapi/storage/sqlite3/filtering.go @@ -0,0 +1,63 @@ +package sqlite3 + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/sirupsen/logrus" +) + +// prepareWithFilters returns a prepared statement with the +// relevant filters included. It also includes an []interface{} +// list of all the relevant parameters to pass straight to +// QueryContext, QueryRowContext etc. +// We don't take the filter object directly here because the +// fields might come from either a StateFilter or an EventFilter, +// and it's easier just to have the caller extract the relevant +// parts. +func prepareWithFilters( + db *sql.DB, query string, params []interface{}, + senders, notsenders, types, nottypes []string, + limit int, order string, +) (*sql.Stmt, []interface{}, error) { + offset := len(params) + if count := len(senders); count > 0 { + query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range senders { + params, offset = append(params, v), offset+1 + } + } + if count := len(notsenders); count > 0 { + query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range notsenders { + params, offset = append(params, v), offset+1 + } + } + if count := len(types); count > 0 { + query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range types { + params, offset = append(params, v), offset+1 + } + } + if count := len(nottypes); count > 0 { + query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range nottypes { + params, offset = append(params, v), offset+1 + } + } + if order != "" { + query += " ORDER BY id " + order + } + query += fmt.Sprintf(" LIMIT $%d", offset+1) + params = append(params, limit) + + logrus.Infof("QUERY: %s", query) + logrus.Infof("PARAMS: %v", params) + + stmt, err := db.Prepare(query) + if err != nil { + return nil, nil, fmt.Errorf("s.db.Prepare: %w", err) + } + return stmt, params, nil +} diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index d7cb85608..d02b449d0 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -21,7 +21,6 @@ import ( "encoding/json" "fmt" "sort" - "strings" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" @@ -30,7 +29,6 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" ) @@ -63,18 +61,18 @@ const selectEventsSQL = "" + const selectRecentEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + - " WHERE room_id = $1 AND id > $2 AND id <= $3" + - " $FILTERS" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const selectRecentEventsForSyncSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + - " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + - " $FILTERS" + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const selectEarlyEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + - " WHERE room_id = $1 AND id > $2 AND id <= $3" + - " $FILTERS" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const selectMaxEventIDSQL = "" + "SELECT MAX(id) FROM syncapi_output_room_events" @@ -86,8 +84,8 @@ const selectStateInRangeSQL = "" + "SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2)" + - " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" + - " $FILTERS" + " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters const deleteEventsForRoomSQL = "" + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" @@ -129,61 +127,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even return s, nil } -// prepareWithFilters returns a prepared statement with the -// relevant filters included. It also includes an []interface{} -// list of all the relevant parameters to pass straight to -// QueryContext, QueryRowContext etc. -// We don't take the filter object directly here because the -// fields might come from either a StateFilter or an EventFilter, -// and it's easier just to have the caller extract the relevant -// parts. -func (s *outputRoomEventsStatements) prepareWithFilters( - query string, params []interface{}, - senders, notsenders, types, nottypes []string, - limit int, order string, -) (*sql.Stmt, []interface{}, error) { - filters := "" - offset := len(params) - if count := len(senders); count > 0 { - filters += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range senders { - params, offset = append(params, v), offset+1 - } - } - if count := len(notsenders); count > 0 { - filters += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range notsenders { - params, offset = append(params, v), offset+1 - } - } - if count := len(types); count > 0 { - filters += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range types { - params, offset = append(params, v), offset+1 - } - } - if count := len(nottypes); count > 0 { - filters += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range nottypes { - params, offset = append(params, v), offset+1 - } - } - filters += " ORDER BY id " + order - filters += fmt.Sprintf(" LIMIT $%d", offset+1) - params = append(params, limit) - - query = strings.Replace(query, " $FILTERS", filters, 1) - - logrus.Infof("QUERY: %s", query) - logrus.Infof("PARAMS: %v", params) - - stmt, err := s.db.Prepare(query) - if err != nil { - return nil, nil, fmt.Errorf("s.db.Prepare: %w", err) - } - return stmt, params, nil -} - func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { headeredJSON, err := json.Marshal(event) if err != nil { @@ -200,8 +143,8 @@ func (s *outputRoomEventsStatements) SelectStateInRange( ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { - stmt, params, err := s.prepareWithFilters( - selectStateInRangeSQL, + stmt, params, err := prepareWithFilters( + s.db, selectStateInRangeSQL, []interface{}{ r.Low(), r.High(), }, @@ -373,8 +316,8 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( query = selectRecentEventsSQL } - stmt, params, err := s.prepareWithFilters( - query, + stmt, params, err := prepareWithFilters( + s.db, query, []interface{}{ roomID, r.Low(), r.High(), }, @@ -421,8 +364,8 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, ) ([]types.StreamEvent, error) { - stmt, params, err := s.prepareWithFilters( - selectEarlyEventsSQL, + stmt, params, err := prepareWithFilters( + s.db, selectEarlyEventsSQL, []interface{}{ roomID, r.Low(), r.High(), }, diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index 5ec0fbd69..58eb531c0 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -62,6 +62,8 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat } if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err == nil { filter = *f + } else { + panic(err) } } }