Fix filtering in current state table

This commit is contained in:
Neil Alexander 2021-01-19 16:49:49 +00:00
parent f8aea35292
commit 6959339e1f
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
4 changed files with 94 additions and 91 deletions

View file

@ -19,6 +19,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt"
"strings" "strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -66,13 +67,8 @@ const selectRoomIDsWithMembershipSQL = "" +
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectCurrentStateSQL = "" + const selectCurrentStateSQL = "" +
"SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1"
" AND ( $2 IS NULL OR sender IN ($2) )" + // WHEN, ORDER BY and LIMIT will be added by prepareWithFilter
" AND ( $3 IS NULL OR NOT(sender IN ($3)) )" +
" AND ( $4 IS NULL OR type IN ($4) )" +
" AND ( $5 IS NULL OR NOT(type IN ($5)) )" +
" AND ( $6 IS NULL OR contains_url = $6 )" +
" LIMIT $7"
const selectJoinedUsersSQL = "" + const selectJoinedUsersSQL = "" +
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
@ -95,7 +91,6 @@ type currentRoomStateStatements struct {
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
DeleteRoomStateForRoomStmt *sql.Stmt DeleteRoomStateForRoomStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
selectCurrentStateStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
} }
@ -121,9 +116,6 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (t
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
return nil, err return nil, err
} }
if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil {
return nil, err
}
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
return nil, err return nil, err
} }
@ -185,17 +177,20 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
// CurrentState returns all the current state events for the given room. // CurrentState returns all the current state events for the given room.
func (s *currentRoomStateStatements) SelectCurrentState( func (s *currentRoomStateStatements) SelectCurrentState(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
stateFilterPart *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) ([]*gomatrixserverlib.HeaderedEvent, error) { ) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) stmt, params, err := prepareWithFilters(
rows, err := stmt.QueryContext(ctx, roomID, s.db, selectCurrentStateSQL,
nil, // FIXME: pq.StringArray(stateFilterPart.Senders), []interface{}{},
nil, // FIXME: pq.StringArray(stateFilterPart.NotSenders), stateFilter.Senders, stateFilter.NotSenders,
nil, // FIXME: pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)), stateFilter.Types, stateFilter.NotTypes,
nil, // FIXME: pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)), stateFilter.Limit, "",
stateFilterPart.ContainsURL,
stateFilterPart.Limit,
) )
if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -0,0 +1,63 @@
package sqlite3
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/sirupsen/logrus"
)
// prepareWithFilters returns a prepared statement with the
// relevant filters included. It also includes an []interface{}
// list of all the relevant parameters to pass straight to
// QueryContext, QueryRowContext etc.
// We don't take the filter object directly here because the
// fields might come from either a StateFilter or an EventFilter,
// and it's easier just to have the caller extract the relevant
// parts.
func prepareWithFilters(
db *sql.DB, query string, params []interface{},
senders, notsenders, types, nottypes []string,
limit int, order string,
) (*sql.Stmt, []interface{}, error) {
offset := len(params)
if count := len(senders); count > 0 {
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range senders {
params, offset = append(params, v), offset+1
}
}
if count := len(notsenders); count > 0 {
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range notsenders {
params, offset = append(params, v), offset+1
}
}
if count := len(types); count > 0 {
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range types {
params, offset = append(params, v), offset+1
}
}
if count := len(nottypes); count > 0 {
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range nottypes {
params, offset = append(params, v), offset+1
}
}
if order != "" {
query += " ORDER BY id " + order
}
query += fmt.Sprintf(" LIMIT $%d", offset+1)
params = append(params, limit)
logrus.Infof("QUERY: %s", query)
logrus.Infof("PARAMS: %v", params)
stmt, err := db.Prepare(query)
if err != nil {
return nil, nil, fmt.Errorf("s.db.Prepare: %w", err)
}
return stmt, params, nil
}

View file

@ -21,7 +21,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"sort" "sort"
"strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -30,7 +29,6 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -63,18 +61,18 @@ const selectEventsSQL = "" +
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3"
" $FILTERS" // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectRecentEventsForSyncSQL = "" + const selectRecentEventsForSyncSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE"
" $FILTERS" // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectEarlyEventsSQL = "" + const selectEarlyEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3"
" $FILTERS" // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectMaxEventIDSQL = "" + const selectMaxEventIDSQL = "" +
"SELECT MAX(id) FROM syncapi_output_room_events" "SELECT MAX(id) FROM syncapi_output_room_events"
@ -86,8 +84,8 @@ const selectStateInRangeSQL = "" +
"SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" + "SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
" FROM syncapi_output_room_events" + " FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2)" + " WHERE (id > $1 AND id <= $2)" +
" AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" + " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))"
" $FILTERS" // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const deleteEventsForRoomSQL = "" + const deleteEventsForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1" "DELETE FROM syncapi_output_room_events WHERE room_id = $1"
@ -129,61 +127,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
return s, nil return s, nil
} }
// prepareWithFilters returns a prepared statement with the
// relevant filters included. It also includes an []interface{}
// list of all the relevant parameters to pass straight to
// QueryContext, QueryRowContext etc.
// We don't take the filter object directly here because the
// fields might come from either a StateFilter or an EventFilter,
// and it's easier just to have the caller extract the relevant
// parts.
func (s *outputRoomEventsStatements) prepareWithFilters(
query string, params []interface{},
senders, notsenders, types, nottypes []string,
limit int, order string,
) (*sql.Stmt, []interface{}, error) {
filters := ""
offset := len(params)
if count := len(senders); count > 0 {
filters += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range senders {
params, offset = append(params, v), offset+1
}
}
if count := len(notsenders); count > 0 {
filters += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range notsenders {
params, offset = append(params, v), offset+1
}
}
if count := len(types); count > 0 {
filters += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range types {
params, offset = append(params, v), offset+1
}
}
if count := len(nottypes); count > 0 {
filters += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range nottypes {
params, offset = append(params, v), offset+1
}
}
filters += " ORDER BY id " + order
filters += fmt.Sprintf(" LIMIT $%d", offset+1)
params = append(params, limit)
query = strings.Replace(query, " $FILTERS", filters, 1)
logrus.Infof("QUERY: %s", query)
logrus.Infof("PARAMS: %v", params)
stmt, err := s.db.Prepare(query)
if err != nil {
return nil, nil, fmt.Errorf("s.db.Prepare: %w", err)
}
return stmt, params, nil
}
func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error {
headeredJSON, err := json.Marshal(event) headeredJSON, err := json.Marshal(event)
if err != nil { if err != nil {
@ -200,8 +143,8 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
ctx context.Context, txn *sql.Tx, r types.Range, ctx context.Context, txn *sql.Tx, r types.Range,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) { ) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt, params, err := s.prepareWithFilters( stmt, params, err := prepareWithFilters(
selectStateInRangeSQL, s.db, selectStateInRangeSQL,
[]interface{}{ []interface{}{
r.Low(), r.High(), r.Low(), r.High(),
}, },
@ -373,8 +316,8 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
query = selectRecentEventsSQL query = selectRecentEventsSQL
} }
stmt, params, err := s.prepareWithFilters( stmt, params, err := prepareWithFilters(
query, s.db, query,
[]interface{}{ []interface{}{
roomID, r.Low(), r.High(), roomID, r.Low(), r.High(),
}, },
@ -421,8 +364,8 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
stmt, params, err := s.prepareWithFilters( stmt, params, err := prepareWithFilters(
selectEarlyEventsSQL, s.db, selectEarlyEventsSQL,
[]interface{}{ []interface{}{
roomID, r.Low(), r.High(), roomID, r.Low(), r.High(),
}, },

View file

@ -62,6 +62,8 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
} }
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err == nil { if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err == nil {
filter = *f filter = *f
} else {
panic(err)
} }
} }
} }