Try to use request filter

This commit is contained in:
Neil Alexander 2021-01-19 10:12:19 +00:00
parent a47b008d36
commit a86c9f7a0c
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
10 changed files with 21 additions and 32 deletions

View file

@ -367,7 +367,6 @@ func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) typ
Timeout: 1 * time.Minute, Timeout: 1 * time.Minute,
Since: since, Since: since,
WantFullState: false, WantFullState: false,
Limit: 20,
Log: util.GetLogger(context.TODO()), Log: util.GetLogger(context.TODO()),
Context: context.TODO(), Context: context.TODO(),
} }

View file

@ -235,7 +235,7 @@ func (r *messagesReq) retrieveEvents() (
clientEvents []gomatrixserverlib.ClientEvent, start, clientEvents []gomatrixserverlib.ClientEvent, start,
end types.TopologyToken, err error, end types.TopologyToken, err error,
) { ) {
eventFilter := gomatrixserverlib.DefaultEventFilter() eventFilter := gomatrixserverlib.DefaultRoomEventFilter()
eventFilter.Limit = r.limit eventFilter.Limit = r.limit
// Retrieve the events from the local database. // Retrieve the events from the local database.

View file

@ -40,7 +40,7 @@ type Database interface {
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.EventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
GetBackwardTopologyPos(ctx context.Context, events []types.StreamEvent) (types.TopologyToken, error) GetBackwardTopologyPos(ctx context.Context, events []types.StreamEvent) (types.TopologyToken, error)
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)
@ -105,7 +105,7 @@ type Database interface {
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
// GetEventsInStreamingRange retrieves all of the events on a given ordering using the given extremities and limit. // GetEventsInStreamingRange retrieves all of the events on a given ordering using the given extremities and limit.
GetEventsInStreamingRange(ctx context.Context, from, to *types.StreamingToken, roomID string, eventFilter *gomatrixserverlib.EventFilter, backwardOrdering bool) (events []types.StreamEvent, err error) GetEventsInStreamingRange(ctx context.Context, from, to *types.StreamingToken, roomID string, eventFilter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event. // EventPositionInTopology returns the depth and stream position of the given event.

View file

@ -330,7 +330,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
// from sync. // from sync.
func (s *outputRoomEventsStatements) SelectRecentEvents( func (s *outputRoomEventsStatements) SelectRecentEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, eventFilter *gomatrixserverlib.EventFilter, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
chronologicalOrder bool, onlySyncEvents bool, chronologicalOrder bool, onlySyncEvents bool,
) ([]types.StreamEvent, bool, error) { ) ([]types.StreamEvent, bool, error) {
var stmt *sql.Stmt var stmt *sql.Stmt

View file

@ -110,7 +110,7 @@ func (d *Database) RoomIDsWithMembership(ctx context.Context, userID string, mem
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership)
} }
func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.EventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
} }
@ -151,7 +151,7 @@ func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixse
func (d *Database) GetEventsInStreamingRange( func (d *Database) GetEventsInStreamingRange(
ctx context.Context, ctx context.Context,
from, to *types.StreamingToken, from, to *types.StreamingToken,
roomID string, eventFilter *gomatrixserverlib.EventFilter, roomID string, eventFilter *gomatrixserverlib.RoomEventFilter,
backwardOrdering bool, backwardOrdering bool,
) (events []types.StreamEvent, err error) { ) (events []types.StreamEvent, err error) {
r := types.Range{ r := types.Range{

View file

@ -336,7 +336,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
func (s *outputRoomEventsStatements) SelectRecentEvents( func (s *outputRoomEventsStatements) SelectRecentEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, eventFilter *gomatrixserverlib.EventFilter, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
chronologicalOrder bool, onlySyncEvents bool, chronologicalOrder bool, onlySyncEvents bool,
) ([]types.StreamEvent, bool, error) { ) ([]types.StreamEvent, bool, error) {
var stmt *sql.Stmt var stmt *sql.Stmt

View file

@ -56,7 +56,7 @@ type Events interface {
// SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high. // SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high.
// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync. // If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync.
// Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`. // Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`.
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.EventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
// SelectEarlyEvents returns the earliest events in the given room. // SelectEarlyEvents returns the earliest events in the given room.
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int) ([]types.StreamEvent, error) SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)

View file

@ -48,9 +48,8 @@ func (p *PDUStreamProvider) CompleteSync(
return from return from
} }
stateFilter := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request stateFilter := req.Filter.Room.State
eventFilter := gomatrixserverlib.DefaultEventFilter() eventFilter := req.Filter.Room.Timeline
eventFilter.Limit = req.Limit
// Build up a /sync response. Add joined rooms. // Build up a /sync response. Add joined rooms.
for _, roomID := range joinedRoomIDs { for _, roomID := range joinedRoomIDs {
@ -106,10 +105,8 @@ func (p *PDUStreamProvider) IncrementalSync(
var stateDeltas []types.StateDelta var stateDeltas []types.StateDelta
var joinedRooms []string var joinedRooms []string
// TODO: use filter provided in request stateFilter := req.Filter.Room.State
stateFilter := gomatrixserverlib.DefaultStateFilter() eventFilter := req.Filter.Room.Timeline
eventFilter := gomatrixserverlib.DefaultEventFilter()
eventFilter.Limit = req.Limit
if req.WantFullState { if req.WantFullState {
if stateDeltas, joinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { if stateDeltas, joinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
@ -142,7 +139,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
device *userapi.Device, device *userapi.Device,
r types.Range, r types.Range,
delta types.StateDelta, delta types.StateDelta,
eventFilter *gomatrixserverlib.EventFilter, eventFilter *gomatrixserverlib.RoomEventFilter,
res *types.Response, res *types.Response,
) error { ) error {
if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave { if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave {
@ -213,7 +210,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
roomID string, roomID string,
r types.Range, r types.Range,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
eventFilter *gomatrixserverlib.EventFilter, eventFilter *gomatrixserverlib.RoomEventFilter,
device *userapi.Device, device *userapi.Device,
) (jr *types.JoinResponse, err error) { ) (jr *types.JoinResponse, err error) {
var stateEvents []*gomatrixserverlib.HeaderedEvent var stateEvents []*gomatrixserverlib.HeaderedEvent

View file

@ -16,6 +16,7 @@ package sync
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
@ -51,16 +52,14 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
return nil, err return nil, err
} }
} }
timelineLimit := DefaultTimelineLimit
// TODO: read from stored filters too // TODO: read from stored filters too
filter := gomatrixserverlib.DefaultFilter()
filterQuery := req.URL.Query().Get("filter") filterQuery := req.URL.Query().Get("filter")
if filterQuery != "" { if filterQuery != "" {
if filterQuery[0] == '{' { if filterQuery[0] == '{' {
// attempt to parse the timeline limit at least // attempt to parse the timeline limit at least
var f filter if err := json.Unmarshal([]byte(filterQuery), &filter); err != nil {
err := json.Unmarshal([]byte(filterQuery), &f) return nil, fmt.Errorf("json.Unmarshal: %w", err)
if err == nil && f.Room.Timeline.Limit != nil {
timelineLimit = *f.Room.Timeline.Limit
} }
} else { } else {
// attempt to load the filter ID // attempt to load the filter ID
@ -71,21 +70,17 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
} }
f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery) f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery)
if err == nil { if err == nil {
timelineLimit = f.Room.Timeline.Limit filter = *f
} }
} }
} }
filter := gomatrixserverlib.DefaultEventFilter()
filter.Limit = timelineLimit
// TODO: Additional query params: set_presence, filter
logger := util.GetLogger(req.Context()).WithFields(logrus.Fields{ logger := util.GetLogger(req.Context()).WithFields(logrus.Fields{
"user_id": device.UserID, "user_id": device.UserID,
"device_id": device.ID, "device_id": device.ID,
"since": since, "since": since,
"timeout": timeout, "timeout": timeout,
"limit": timelineLimit, "limit": filter.Room.Timeline.Limit,
}) })
return &types.SyncRequest{ return &types.SyncRequest{
@ -96,7 +91,6 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
Filter: filter, // Filter: filter, //
Since: since, // Since: since, //
Timeout: timeout, // Timeout: timeout, //
Limit: timelineLimit, //
Rooms: make(map[string]string), // Populated by the PDU stream Rooms: make(map[string]string), // Populated by the PDU stream
WantFullState: wantFullState, // WantFullState: wantFullState, //
}, nil }, nil

View file

@ -14,9 +14,8 @@ type SyncRequest struct {
Log *logrus.Entry Log *logrus.Entry
Device *userapi.Device Device *userapi.Device
Response *Response Response *Response
Filter gomatrixserverlib.EventFilter Filter gomatrixserverlib.Filter
Since StreamingToken Since StreamingToken
Limit int
Timeout time.Duration Timeout time.Duration
WantFullState bool WantFullState bool