Apply filters correctly

This commit is contained in:
Till Faelligen 2022-04-08 11:51:19 +02:00
parent f299f97e0a
commit f564547d30
7 changed files with 83 additions and 38 deletions

View file

@ -60,7 +60,11 @@ func Context(
Headers: nil, Headers: nil,
} }
} }
filter.Rooms = append(filter.Rooms, roomID) if filter.Rooms != nil {
roomsFilter := *filter.Rooms
roomsFilter = append(roomsFilter, roomID)
filter.Rooms = &roomsFilter
}
ctx := req.Context() ctx := req.Context()
membershipRes := roomserver.QueryMembershipForUserResponse{} membershipRes := roomserver.QueryMembershipForUserResponse{}

View file

@ -233,9 +233,10 @@ func (s *currentRoomStateStatements) SelectCurrentState(
excludeEventIDs []string, excludeEventIDs []string,
) ([]*gomatrixserverlib.HeaderedEvent, error) { ) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err := stmt.QueryContext(ctx, roomID, rows, err := stmt.QueryContext(ctx, roomID,
pq.StringArray(stateFilter.Senders), pq.StringArray(senders),
pq.StringArray(stateFilter.NotSenders), pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
stateFilter.ContainsURL, stateFilter.ContainsURL,

View file

@ -16,21 +16,45 @@ package postgres
import ( import (
"strings" "strings"
"github.com/matrix-org/gomatrixserverlib"
) )
// filterConvertWildcardToSQL converts wildcards as defined in // filterConvertWildcardToSQL converts wildcards as defined in
// https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter // https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter
// to SQL wildcards that can be used with LIKE() // to SQL wildcards that can be used with LIKE()
func filterConvertTypeWildcardToSQL(values []string) []string { func filterConvertTypeWildcardToSQL(values *[]string) []string {
if values == nil { if values == nil {
// Return nil instead of []string{} so IS NULL can work correctly when // Return nil instead of []string{} so IS NULL can work correctly when
// the return value is passed into SQL queries // the return value is passed into SQL queries
return nil return nil
} }
ret := make([]string, len(values)) v := *values
for i := range values { ret := make([]string, len(v))
ret[i] = strings.Replace(values[i], "*", "%", -1) for i := range v {
ret[i] = strings.Replace(v[i], "*", "%", -1)
} }
return ret return ret
} }
// TODO: Replace when Dendrite uses Go 1.18
func getSendersRoomEventFilter(filter *gomatrixserverlib.RoomEventFilter) (senders []string, notSenders []string) {
if filter.Senders != nil {
senders = *filter.Senders
}
if filter.NotSenders != nil {
notSenders = *filter.NotSenders
}
return senders, notSenders
}
func getSendersStateFilterFilter(filter *gomatrixserverlib.StateFilter) (senders []string, notSenders []string) {
if filter.Senders != nil {
senders = *filter.Senders
}
if filter.NotSenders != nil {
notSenders = *filter.NotSenders
}
return senders, notSenders
}

View file

@ -204,11 +204,11 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
stateFilter *gomatrixserverlib.StateFilter, roomIDs []string, stateFilter *gomatrixserverlib.StateFilter, roomIDs []string,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) { ) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err := stmt.QueryContext( rows, err := stmt.QueryContext(
ctx, r.Low(), r.High(), pq.StringArray(roomIDs), ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
pq.StringArray(stateFilter.Senders), pq.StringArray(senders),
pq.StringArray(stateFilter.NotSenders), pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
stateFilter.ContainsURL, stateFilter.ContainsURL,
@ -353,10 +353,11 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
} else { } else {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
} }
senders, notSenders := getSendersRoomEventFilter(eventFilter)
rows, err := stmt.QueryContext( rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(), ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders), pq.StringArray(senders),
pq.StringArray(eventFilter.NotSenders), pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit+1, eventFilter.Limit+1,
@ -398,11 +399,12 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
senders, notSenders := getSendersRoomEventFilter(eventFilter)
stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt) stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt)
rows, err := stmt.QueryContext( rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(), ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders), pq.StringArray(senders),
pq.StringArray(eventFilter.NotSenders), pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit, eventFilter.Limit,
@ -462,10 +464,11 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn
func (s *outputRoomEventsStatements) SelectContextBeforeEvent( func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter, ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
) (evts []*gomatrixserverlib.HeaderedEvent, err error) { ) (evts []*gomatrixserverlib.HeaderedEvent, err error) {
senders, notSenders := getSendersRoomEventFilter(filter)
rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext( rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext(
ctx, roomID, id, filter.Limit, ctx, roomID, id, filter.Limit,
pq.StringArray(filter.Senders), pq.StringArray(senders),
pq.StringArray(filter.NotSenders), pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)), pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
) )
@ -494,10 +497,11 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
func (s *outputRoomEventsStatements) SelectContextAfterEvent( func (s *outputRoomEventsStatements) SelectContextAfterEvent(
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter, ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) { ) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) {
senders, notSenders := getSendersRoomEventFilter(filter)
rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext( rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext(
ctx, roomID, id, filter.Limit, ctx, roomID, id, filter.Limit,
pq.StringArray(filter.Senders), pq.StringArray(senders),
pq.StringArray(filter.NotSenders), pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)), pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
) )

View file

@ -119,12 +119,12 @@ func (s *accountDataStatements) SelectAccountDataInRange(
// and positional parameters makes the query annoyingly hard to do, it's easier // and positional parameters makes the query annoyingly hard to do, it's easier
// and clearer to do it in Go-land. If there are no filters for [not]types then // and clearer to do it in Go-land. If there are no filters for [not]types then
// this gets skipped. // this gets skipped.
for _, includeType := range accountDataFilterPart.Types { for _, includeType := range *accountDataFilterPart.Types {
if includeType != dataType { // TODO: wildcard support if includeType != dataType { // TODO: wildcard support
continue continue
} }
} }
for _, excludeType := range accountDataFilterPart.NotTypes { for _, excludeType := range *accountDataFilterPart.NotTypes {
if excludeType == dataType { // TODO: wildcard support if excludeType == dataType { // TODO: wildcard support
continue continue
} }

View file

@ -25,34 +25,42 @@ const (
// parts. // parts.
func prepareWithFilters( func prepareWithFilters(
db *sql.DB, txn *sql.Tx, query string, params []interface{}, db *sql.DB, txn *sql.Tx, query string, params []interface{},
senders, notsenders, types, nottypes []string, excludeEventIDs []string, senders, notsenders, types, nottypes *[]string, excludeEventIDs []string,
limit int, order FilterOrder, limit int, order FilterOrder,
) (*sql.Stmt, []interface{}, error) { ) (*sql.Stmt, []interface{}, error) {
offset := len(params) offset := len(params)
if count := len(senders); count > 0 { if senders != nil {
if count := len(*senders); count > 0 {
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset) query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range senders { for _, v := range *senders {
params, offset = append(params, v), offset+1 params, offset = append(params, v), offset+1
} }
} }
if count := len(notsenders); count > 0 { }
if notsenders != nil {
if count := len(*notsenders); count > 0 {
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset) query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range notsenders { for _, v := range *notsenders {
params, offset = append(params, v), offset+1 params, offset = append(params, v), offset+1
} }
} }
if count := len(types); count > 0 { }
if types != nil {
if count := len(*types); count > 0 {
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset) query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range types { for _, v := range *types {
params, offset = append(params, v), offset+1 params, offset = append(params, v), offset+1
} }
} }
if count := len(nottypes); count > 0 { }
if nottypes != nil {
if count := len(*nottypes); count > 0 {
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset) query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range nottypes { for _, v := range *nottypes {
params, offset = append(params, v), offset+1 params, offset = append(params, v), offset+1
} }
} }
}
if count := len(excludeEventIDs); count > 0 { if count := len(excludeEventIDs); count > 0 {
query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset) query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range excludeEventIDs { for _, v := range excludeEventIDs {

View file

@ -423,8 +423,12 @@ func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *ty
return err return err
} }
req.IgnoredUsers = *ignores req.IgnoredUsers = *ignores
var userList []string = nil
for userID := range ignores.List { for userID := range ignores.List {
eventFilter.NotSenders = append(eventFilter.NotSenders, userID) userList = append(userList, userID)
}
if len(userList) > 0 {
eventFilter.NotSenders = &userList
} }
return nil return nil
} }