mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 00:03:09 -06:00
Fix filtering in current state table
This commit is contained in:
parent
f8aea35292
commit
6959339e1f
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
|
@ -66,13 +67,8 @@ const selectRoomIDsWithMembershipSQL = "" +
|
||||||
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
|
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
|
||||||
|
|
||||||
const selectCurrentStateSQL = "" +
|
const selectCurrentStateSQL = "" +
|
||||||
"SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" +
|
"SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1"
|
||||||
" AND ( $2 IS NULL OR sender IN ($2) )" +
|
// WHEN, ORDER BY and LIMIT will be added by prepareWithFilter
|
||||||
" AND ( $3 IS NULL OR NOT(sender IN ($3)) )" +
|
|
||||||
" AND ( $4 IS NULL OR type IN ($4) )" +
|
|
||||||
" AND ( $5 IS NULL OR NOT(type IN ($5)) )" +
|
|
||||||
" AND ( $6 IS NULL OR contains_url = $6 )" +
|
|
||||||
" LIMIT $7"
|
|
||||||
|
|
||||||
const selectJoinedUsersSQL = "" +
|
const selectJoinedUsersSQL = "" +
|
||||||
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
|
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
|
||||||
|
|
@ -95,7 +91,6 @@ type currentRoomStateStatements struct {
|
||||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||||
DeleteRoomStateForRoomStmt *sql.Stmt
|
DeleteRoomStateForRoomStmt *sql.Stmt
|
||||||
selectRoomIDsWithMembershipStmt *sql.Stmt
|
selectRoomIDsWithMembershipStmt *sql.Stmt
|
||||||
selectCurrentStateStmt *sql.Stmt
|
|
||||||
selectJoinedUsersStmt *sql.Stmt
|
selectJoinedUsersStmt *sql.Stmt
|
||||||
selectStateEventStmt *sql.Stmt
|
selectStateEventStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
@ -121,9 +116,6 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (t
|
||||||
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
|
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
|
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -185,17 +177,20 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
|
||||||
// CurrentState returns all the current state events for the given room.
|
// CurrentState returns all the current state events for the given room.
|
||||||
func (s *currentRoomStateStatements) SelectCurrentState(
|
func (s *currentRoomStateStatements) SelectCurrentState(
|
||||||
ctx context.Context, txn *sql.Tx, roomID string,
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
stateFilterPart *gomatrixserverlib.StateFilter,
|
stateFilter *gomatrixserverlib.StateFilter,
|
||||||
) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt)
|
stmt, params, err := prepareWithFilters(
|
||||||
rows, err := stmt.QueryContext(ctx, roomID,
|
s.db, selectCurrentStateSQL,
|
||||||
nil, // FIXME: pq.StringArray(stateFilterPart.Senders),
|
[]interface{}{},
|
||||||
nil, // FIXME: pq.StringArray(stateFilterPart.NotSenders),
|
stateFilter.Senders, stateFilter.NotSenders,
|
||||||
nil, // FIXME: pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)),
|
stateFilter.Types, stateFilter.NotTypes,
|
||||||
nil, // FIXME: pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)),
|
stateFilter.Limit, "",
|
||||||
stateFilterPart.ContainsURL,
|
|
||||||
stateFilterPart.Limit,
|
|
||||||
)
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := stmt.QueryContext(ctx, params...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
63
syncapi/storage/sqlite3/filtering.go
Normal file
63
syncapi/storage/sqlite3/filtering.go
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// prepareWithFilters returns a prepared statement with the
|
||||||
|
// relevant filters included. It also includes an []interface{}
|
||||||
|
// list of all the relevant parameters to pass straight to
|
||||||
|
// QueryContext, QueryRowContext etc.
|
||||||
|
// We don't take the filter object directly here because the
|
||||||
|
// fields might come from either a StateFilter or an EventFilter,
|
||||||
|
// and it's easier just to have the caller extract the relevant
|
||||||
|
// parts.
|
||||||
|
func prepareWithFilters(
|
||||||
|
db *sql.DB, query string, params []interface{},
|
||||||
|
senders, notsenders, types, nottypes []string,
|
||||||
|
limit int, order string,
|
||||||
|
) (*sql.Stmt, []interface{}, error) {
|
||||||
|
offset := len(params)
|
||||||
|
if count := len(senders); count > 0 {
|
||||||
|
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||||
|
for _, v := range senders {
|
||||||
|
params, offset = append(params, v), offset+1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if count := len(notsenders); count > 0 {
|
||||||
|
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||||
|
for _, v := range notsenders {
|
||||||
|
params, offset = append(params, v), offset+1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if count := len(types); count > 0 {
|
||||||
|
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||||
|
for _, v := range types {
|
||||||
|
params, offset = append(params, v), offset+1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if count := len(nottypes); count > 0 {
|
||||||
|
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||||
|
for _, v := range nottypes {
|
||||||
|
params, offset = append(params, v), offset+1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if order != "" {
|
||||||
|
query += " ORDER BY id " + order
|
||||||
|
}
|
||||||
|
query += fmt.Sprintf(" LIMIT $%d", offset+1)
|
||||||
|
params = append(params, limit)
|
||||||
|
|
||||||
|
logrus.Infof("QUERY: %s", query)
|
||||||
|
logrus.Infof("PARAMS: %v", params)
|
||||||
|
|
||||||
|
stmt, err := db.Prepare(query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("s.db.Prepare: %w", err)
|
||||||
|
}
|
||||||
|
return stmt, params, nil
|
||||||
|
}
|
||||||
|
|
@ -21,7 +21,6 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
|
@ -30,7 +29,6 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -63,18 +61,18 @@ const selectEventsSQL = "" +
|
||||||
|
|
||||||
const selectRecentEventsSQL = "" +
|
const selectRecentEventsSQL = "" +
|
||||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||||
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
|
" WHERE room_id = $1 AND id > $2 AND id <= $3"
|
||||||
" $FILTERS"
|
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
|
||||||
|
|
||||||
const selectRecentEventsForSyncSQL = "" +
|
const selectRecentEventsForSyncSQL = "" +
|
||||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||||
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +
|
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE"
|
||||||
" $FILTERS"
|
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
|
||||||
|
|
||||||
const selectEarlyEventsSQL = "" +
|
const selectEarlyEventsSQL = "" +
|
||||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||||
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
|
" WHERE room_id = $1 AND id > $2 AND id <= $3"
|
||||||
" $FILTERS"
|
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
|
||||||
|
|
||||||
const selectMaxEventIDSQL = "" +
|
const selectMaxEventIDSQL = "" +
|
||||||
"SELECT MAX(id) FROM syncapi_output_room_events"
|
"SELECT MAX(id) FROM syncapi_output_room_events"
|
||||||
|
|
@ -86,8 +84,8 @@ const selectStateInRangeSQL = "" +
|
||||||
"SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
|
"SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
|
||||||
" FROM syncapi_output_room_events" +
|
" FROM syncapi_output_room_events" +
|
||||||
" WHERE (id > $1 AND id <= $2)" +
|
" WHERE (id > $1 AND id <= $2)" +
|
||||||
" AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" +
|
" AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))"
|
||||||
" $FILTERS"
|
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
|
||||||
|
|
||||||
const deleteEventsForRoomSQL = "" +
|
const deleteEventsForRoomSQL = "" +
|
||||||
"DELETE FROM syncapi_output_room_events WHERE room_id = $1"
|
"DELETE FROM syncapi_output_room_events WHERE room_id = $1"
|
||||||
|
|
@ -129,61 +127,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareWithFilters returns a prepared statement with the
|
|
||||||
// relevant filters included. It also includes an []interface{}
|
|
||||||
// list of all the relevant parameters to pass straight to
|
|
||||||
// QueryContext, QueryRowContext etc.
|
|
||||||
// We don't take the filter object directly here because the
|
|
||||||
// fields might come from either a StateFilter or an EventFilter,
|
|
||||||
// and it's easier just to have the caller extract the relevant
|
|
||||||
// parts.
|
|
||||||
func (s *outputRoomEventsStatements) prepareWithFilters(
|
|
||||||
query string, params []interface{},
|
|
||||||
senders, notsenders, types, nottypes []string,
|
|
||||||
limit int, order string,
|
|
||||||
) (*sql.Stmt, []interface{}, error) {
|
|
||||||
filters := ""
|
|
||||||
offset := len(params)
|
|
||||||
if count := len(senders); count > 0 {
|
|
||||||
filters += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
|
|
||||||
for _, v := range senders {
|
|
||||||
params, offset = append(params, v), offset+1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if count := len(notsenders); count > 0 {
|
|
||||||
filters += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
|
||||||
for _, v := range notsenders {
|
|
||||||
params, offset = append(params, v), offset+1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if count := len(types); count > 0 {
|
|
||||||
filters += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
|
|
||||||
for _, v := range types {
|
|
||||||
params, offset = append(params, v), offset+1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if count := len(nottypes); count > 0 {
|
|
||||||
filters += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
|
||||||
for _, v := range nottypes {
|
|
||||||
params, offset = append(params, v), offset+1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
filters += " ORDER BY id " + order
|
|
||||||
filters += fmt.Sprintf(" LIMIT $%d", offset+1)
|
|
||||||
params = append(params, limit)
|
|
||||||
|
|
||||||
query = strings.Replace(query, " $FILTERS", filters, 1)
|
|
||||||
|
|
||||||
logrus.Infof("QUERY: %s", query)
|
|
||||||
logrus.Infof("PARAMS: %v", params)
|
|
||||||
|
|
||||||
stmt, err := s.db.Prepare(query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("s.db.Prepare: %w", err)
|
|
||||||
}
|
|
||||||
return stmt, params, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error {
|
func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error {
|
||||||
headeredJSON, err := json.Marshal(event)
|
headeredJSON, err := json.Marshal(event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -200,8 +143,8 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
|
||||||
ctx context.Context, txn *sql.Tx, r types.Range,
|
ctx context.Context, txn *sql.Tx, r types.Range,
|
||||||
stateFilter *gomatrixserverlib.StateFilter,
|
stateFilter *gomatrixserverlib.StateFilter,
|
||||||
) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
|
) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
|
||||||
stmt, params, err := s.prepareWithFilters(
|
stmt, params, err := prepareWithFilters(
|
||||||
selectStateInRangeSQL,
|
s.db, selectStateInRangeSQL,
|
||||||
[]interface{}{
|
[]interface{}{
|
||||||
r.Low(), r.High(),
|
r.Low(), r.High(),
|
||||||
},
|
},
|
||||||
|
|
@ -373,8 +316,8 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
|
||||||
query = selectRecentEventsSQL
|
query = selectRecentEventsSQL
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt, params, err := s.prepareWithFilters(
|
stmt, params, err := prepareWithFilters(
|
||||||
query,
|
s.db, query,
|
||||||
[]interface{}{
|
[]interface{}{
|
||||||
roomID, r.Low(), r.High(),
|
roomID, r.Low(), r.High(),
|
||||||
},
|
},
|
||||||
|
|
@ -421,8 +364,8 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
|
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
|
||||||
) ([]types.StreamEvent, error) {
|
) ([]types.StreamEvent, error) {
|
||||||
stmt, params, err := s.prepareWithFilters(
|
stmt, params, err := prepareWithFilters(
|
||||||
selectEarlyEventsSQL,
|
s.db, selectEarlyEventsSQL,
|
||||||
[]interface{}{
|
[]interface{}{
|
||||||
roomID, r.Low(), r.High(),
|
roomID, r.Low(), r.High(),
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,8 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
|
||||||
}
|
}
|
||||||
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err == nil {
|
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err == nil {
|
||||||
filter = *f
|
filter = *f
|
||||||
|
} else {
|
||||||
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue