Replace SyncPosition with PaginationToken throughout syncapi

This commit is contained in:
Neil Alexander 2020-01-21 10:30:01 +00:00
parent 50f26a3002
commit 2541946c1f
16 changed files with 184 additions and 190 deletions

View file

@ -90,7 +90,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error
}).Panicf("could not save account data") }).Panicf("could not save account data")
} }
s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.SyncPosition{PDUPosition: pduPos}) s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.PaginationToken{PDUPosition: pduPos})
return nil return nil
} }

View file

@ -145,7 +145,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
}).Panicf("roomserver output log: write event failure") }).Panicf("roomserver output log: write event failure")
return nil return nil
} }
s.notifier.OnNewEvent(&ev, "", nil, types.SyncPosition{PDUPosition: pduPos}) s.notifier.OnNewEvent(&ev, "", nil, types.PaginationToken{PDUPosition: pduPos})
return nil return nil
} }
@ -162,7 +162,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
}).Panicf("roomserver output log: write invite failure") }).Panicf("roomserver output log: write invite failure")
return nil return nil
} }
s.notifier.OnNewEvent(&msg.Event, "", nil, types.SyncPosition{PDUPosition: pduPos}) s.notifier.OnNewEvent(&msg.Event, "", nil, types.PaginationToken{PDUPosition: pduPos})
return nil return nil
} }

View file

@ -63,7 +63,12 @@ func NewOutputTypingEventConsumer(
// Start consuming from typing api // Start consuming from typing api
func (s *OutputTypingEventConsumer) Start() error { func (s *OutputTypingEventConsumer) Start() error {
s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) { s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) {
s.notifier.OnNewEvent(nil, roomID, nil, types.SyncPosition{TypingPosition: latestSyncPosition}) s.notifier.OnNewEvent(
nil, roomID, nil,
types.PaginationToken{
EDUTypingPosition: types.StreamPosition(latestSyncPosition),
},
)
}) })
return s.typingConsumer.Start() return s.typingConsumer.Start()
@ -83,7 +88,7 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error
"typing": output.Event.Typing, "typing": output.Event.Typing,
}).Debug("received data from typing server") }).Debug("received data from typing server")
var typingPos int64 var typingPos types.StreamPosition
typingEvent := output.Event typingEvent := output.Event
if typingEvent.Typing { if typingEvent.Typing {
typingPos = s.db.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime) typingPos = s.db.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime)
@ -91,6 +96,6 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error
typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID) typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID)
} }
s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.SyncPosition{TypingPosition: typingPos}) s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.PaginationToken{EDUTypingPosition: typingPos})
return nil return nil
} }

View file

@ -236,10 +236,10 @@ func (r *messagesReq) retrieveEvents() (
// Generate pagination tokens to send to the client using the positions // Generate pagination tokens to send to the client using the positions
// retrieved previously. // retrieved previously.
start = types.NewPaginationTokenFromTypeAndPosition( start = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, startPos, types.PaginationTokenTypeTopology, startPos, 0,
) )
end = types.NewPaginationTokenFromTypeAndPosition( end = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, endPos, types.PaginationTokenTypeTopology, endPos, 0,
) )
if r.backwardOrdering { if r.backwardOrdering {
@ -248,13 +248,13 @@ func (r *messagesReq) retrieveEvents() (
// we consider a left to right chronological order), tokens need to refer // we consider a left to right chronological order), tokens need to refer
// to them by the event on their left, therefore we need to decrement the // to them by the event on their left, therefore we need to decrement the
// end position we send in the response if we're going backward. // end position we send in the response if we're going backward.
end.Position-- end.PDUPosition--
} }
// The lowest token value is 1, therefore we need to manually set it to that // The lowest token value is 1, therefore we need to manually set it to that
// value if we're below it. // value if we're below it.
if end.Position < types.StreamPosition(1) { if end.PDUPosition < types.StreamPosition(1) {
end.Position = types.StreamPosition(1) end.PDUPosition = types.StreamPosition(1)
} }
return return
@ -303,10 +303,10 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
if r.wasToProvided { if r.wasToProvided {
// The condition in the SQL query is a strict "greater than" so // The condition in the SQL query is a strict "greater than" so
// we need to check against to-1. // we need to check against to-1.
isSetLargeEnough = (r.to.Position-1 == types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition)) isSetLargeEnough = (r.to.PDUPosition-1 == types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition))
} }
} else { } else {
isSetLargeEnough = (r.from.Position-1 == types.StreamPosition(streamEvents[0].StreamPosition)) isSetLargeEnough = (r.from.PDUPosition-1 == types.StreamPosition(streamEvents[0].StreamPosition))
} }
} }
@ -456,7 +456,7 @@ func setToDefault(
roomID string, roomID string,
) (to *types.PaginationToken, err error) { ) (to *types.PaginationToken, err error) {
if backwardOrdering { if backwardOrdering {
to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 1) to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 1, 0)
} else { } else {
var pos types.StreamPosition var pos types.StreamPosition
pos, err = db.MaxTopologicalPosition(ctx, roomID) pos, err = db.MaxTopologicalPosition(ctx, roomID)
@ -464,7 +464,7 @@ func setToDefault(
return return
} }
to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, pos) to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, pos, 0)
} }
return return

View file

@ -21,6 +21,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
) )
@ -89,7 +90,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, ctx context.Context,
userID, roomID, dataType string, userID, roomID, dataType string,
) (pos int64, err error) { ) (pos types.StreamPosition, err error) {
err = s.insertAccountDataStmt.QueryRowContext(ctx, userID, roomID, dataType).Scan(&pos) err = s.insertAccountDataStmt.QueryRowContext(ctx, userID, roomID, dataType).Scan(&pos)
return return
} }
@ -97,7 +98,7 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountDataInRange( func (s *accountDataStatements) selectAccountDataInRange(
ctx context.Context, ctx context.Context,
userID string, userID string,
oldPos, newPos int64, oldPos, newPos types.StreamPosition,
accountDataFilterPart *gomatrix.FilterPart, accountDataFilterPart *gomatrix.FilterPart,
) (data map[string][]string, err error) { ) (data map[string][]string, err error) {
data = make(map[string][]string) data = make(map[string][]string)

View file

@ -214,7 +214,7 @@ func (s *currentRoomStateStatements) deleteRoomStateByEventID(
func (s *currentRoomStateStatements) upsertRoomState( func (s *currentRoomStateStatements) upsertRoomState(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
event gomatrixserverlib.Event, membership *string, addedAt int64, event gomatrixserverlib.Event, membership *string, addedAt types.StreamPosition,
) error { ) error {
// Parse content as JSON and search for an "url" key // Parse content as JSON and search for an "url" key
containsURL := false containsURL := false

View file

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -86,7 +87,7 @@ func (s *inviteEventsStatements) prepare(db *sql.DB) (err error) {
func (s *inviteEventsStatements) insertInviteEvent( func (s *inviteEventsStatements) insertInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.Event, ctx context.Context, inviteEvent gomatrixserverlib.Event,
) (streamPos int64, err error) { ) (streamPos types.StreamPosition, err error) {
err = s.insertInviteEventStmt.QueryRowContext( err = s.insertInviteEventStmt.QueryRowContext(
ctx, ctx,
inviteEvent.RoomID(), inviteEvent.RoomID(),
@ -107,7 +108,7 @@ func (s *inviteEventsStatements) deleteInviteEvent(
// selectInviteEventsInRange returns a map of room ID to invite event for the // selectInviteEventsInRange returns a map of room ID to invite event for the
// active invites for the target user ID in the supplied range. // active invites for the target user ID in the supplied range.
func (s *inviteEventsStatements) selectInviteEventsInRange( func (s *inviteEventsStatements) selectInviteEventsInRange(
ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos int64, ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos types.StreamPosition,
) (map[string]gomatrixserverlib.Event, error) { ) (map[string]gomatrixserverlib.Event, error) {
stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt) stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt)
rows, err := stmt.QueryContext(ctx, targetUserID, startPos, endPos) rows, err := stmt.QueryContext(ctx, targetUserID, startPos, endPos)

View file

@ -151,7 +151,7 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
// two positions, only the most recent state is returned. // two positions, only the most recent state is returned.
func (s *outputRoomEventsStatements) selectStateInRange( func (s *outputRoomEventsStatements) selectStateInRange(
ctx context.Context, txn *sql.Tx, oldPos, newPos int64, ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition,
stateFilterPart *gomatrix.FilterPart, stateFilterPart *gomatrix.FilterPart,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) { ) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := common.TxStmt(txn, s.selectStateInRangeStmt) stmt := common.TxStmt(txn, s.selectStateInRangeStmt)
@ -180,7 +180,7 @@ func (s *outputRoomEventsStatements) selectStateInRange(
for rows.Next() { for rows.Next() {
var ( var (
streamPos int64 streamPos types.StreamPosition
eventBytes []byte eventBytes []byte
excludeFromSync bool excludeFromSync bool
addIDs pq.StringArray addIDs pq.StringArray
@ -248,7 +248,7 @@ func (s *outputRoomEventsStatements) insertEvent(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
event *gomatrixserverlib.Event, addState, removeState []string, event *gomatrixserverlib.Event, addState, removeState []string,
transactionID *api.TransactionID, excludeFromSync bool, transactionID *api.TransactionID, excludeFromSync bool,
) (streamPos int64, err error) { ) (streamPos types.StreamPosition, err error) {
var txnID *string var txnID *string
var sessionID *int64 var sessionID *int64
if transactionID != nil { if transactionID != nil {
@ -360,7 +360,7 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
var result []types.StreamEvent var result []types.StreamEvent
for rows.Next() { for rows.Next() {
var ( var (
streamPos int64 streamPos types.StreamPosition
eventBytes []byte eventBytes []byte
excludeFromSync bool excludeFromSync bool
sessionID *int64 sessionID *int64

View file

@ -43,7 +43,7 @@ type stateDelta struct {
membership string membership string
// The PDU stream position of the latest membership event for this user, if applicable. // The PDU stream position of the latest membership event for this user, if applicable.
// Can be 0 if there is no membership event in this delta. // Can be 0 if there is no membership event in this delta.
membershipPos int64 membershipPos types.StreamPosition
} }
// SyncServerDatasource represents a sync server datasource which manages // SyncServerDatasource represents a sync server datasource which manages
@ -122,7 +122,7 @@ func (d *SyncServerDatasource) WriteEvent(
addStateEvents []gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event,
addStateEventIDs, removeStateEventIDs []string, addStateEventIDs, removeStateEventIDs []string,
transactionID *api.TransactionID, excludeFromSync bool, transactionID *api.TransactionID, excludeFromSync bool,
) (pduPosition int64, returnErr error) { ) (pduPosition types.StreamPosition, returnErr error) {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error var err error
pos, err := d.events.insertEvent( pos, err := d.events.insertEvent(
@ -186,7 +186,7 @@ func (d *SyncServerDatasource) updateRoomState(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
removedEventIDs []string, removedEventIDs []string,
addedEvents []gomatrixserverlib.Event, addedEvents []gomatrixserverlib.Event,
pduPosition int64, pduPosition types.StreamPosition,
) error { ) error {
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
for _, eventID := range removedEventIDs { for _, eventID := range removedEventIDs {
@ -256,12 +256,12 @@ func (d *SyncServerDatasource) GetEventsInRange(
if backwardOrdering { if backwardOrdering {
// Backward ordering is antichronological (latest event to oldest // Backward ordering is antichronological (latest event to oldest
// one). // one).
backwardLimit = to.Position backwardLimit = to.PDUPosition
forwardLimit = from.Position forwardLimit = from.PDUPosition
} else { } else {
// Forward ordering is chronological (oldest event to latest one). // Forward ordering is chronological (oldest event to latest one).
backwardLimit = from.Position backwardLimit = from.PDUPosition
forwardLimit = to.Position forwardLimit = to.PDUPosition
} }
// Select the event IDs from the defined range. // Select the event IDs from the defined range.
@ -285,14 +285,14 @@ func (d *SyncServerDatasource) GetEventsInRange(
if backwardOrdering { if backwardOrdering {
// When using backward ordering, we want the most recent events first. // When using backward ordering, we want the most recent events first.
if events, err = d.events.selectRecentEvents( if events, err = d.events.selectRecentEvents(
ctx, nil, roomID, to.Position, from.Position, limit, false, false, ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false,
); err != nil { ); err != nil {
return return
} }
} else { } else {
// When using forward ordering, we want the least recent events first. // When using forward ordering, we want the least recent events first.
if events, err = d.events.selectEarlyEvents( if events, err = d.events.selectEarlyEvents(
ctx, nil, roomID, from.Position, to.Position, limit, ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit,
); err != nil { ); err != nil {
return return
} }
@ -302,7 +302,7 @@ func (d *SyncServerDatasource) GetEventsInRange(
} }
// SyncPosition returns the latest positions for syncing. // SyncPosition returns the latest positions for syncing.
func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.SyncPosition, error) { func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.PaginationToken, error) {
return d.syncPositionTx(ctx, nil) return d.syncPositionTx(ctx, nil)
} }
@ -372,7 +372,7 @@ func (d *SyncServerDatasource) syncStreamPositionTx(
func (d *SyncServerDatasource) syncPositionTx( func (d *SyncServerDatasource) syncPositionTx(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (sp types.SyncPosition, err error) { ) (sp types.PaginationToken, err error) {
maxEventID, err := d.events.selectMaxEventID(ctx, txn) maxEventID, err := d.events.selectMaxEventID(ctx, txn)
if err != nil { if err != nil {
@ -392,10 +392,8 @@ func (d *SyncServerDatasource) syncPositionTx(
if maxInviteID > maxEventID { if maxInviteID > maxEventID {
maxEventID = maxInviteID maxEventID = maxInviteID
} }
sp.PDUPosition = maxEventID sp.PDUPosition = types.StreamPosition(maxEventID)
sp.EDUTypingPosition = types.StreamPosition(d.typingCache.GetLatestSyncPosition())
sp.TypingPosition = d.typingCache.GetLatestSyncPosition()
return return
} }
@ -404,7 +402,7 @@ func (d *SyncServerDatasource) syncPositionTx(
func (d *SyncServerDatasource) addPDUDeltaToResponse( func (d *SyncServerDatasource) addPDUDeltaToResponse(
ctx context.Context, ctx context.Context,
device authtypes.Device, device authtypes.Device,
fromPos, toPos int64, fromPos, toPos types.StreamPosition,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
wantFullState bool, wantFullState bool,
res *types.Response, res *types.Response,
@ -456,7 +454,7 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse(
// addTypingDeltaToResponse adds all typing notifications to a sync response // addTypingDeltaToResponse adds all typing notifications to a sync response
// since the specified position. // since the specified position.
func (d *SyncServerDatasource) addTypingDeltaToResponse( func (d *SyncServerDatasource) addTypingDeltaToResponse(
since int64, since types.PaginationToken,
joinedRoomIDs []string, joinedRoomIDs []string,
res *types.Response, res *types.Response,
) error { ) error {
@ -465,7 +463,7 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse(
var err error var err error
for _, roomID := range joinedRoomIDs { for _, roomID := range joinedRoomIDs {
if typingUsers, updated := d.typingCache.GetTypingUsersIfUpdatedAfter( if typingUsers, updated := d.typingCache.GetTypingUsersIfUpdatedAfter(
roomID, since, roomID, int64(since.EDUTypingPosition),
); updated { ); updated {
ev := gomatrixserverlib.ClientEvent{ ev := gomatrixserverlib.ClientEvent{
Type: gomatrixserverlib.MTyping, Type: gomatrixserverlib.MTyping,
@ -490,14 +488,14 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse(
// addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if // addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if
// the positions of that type are not equal in fromPos and toPos. // the positions of that type are not equal in fromPos and toPos.
func (d *SyncServerDatasource) addEDUDeltaToResponse( func (d *SyncServerDatasource) addEDUDeltaToResponse(
fromPos, toPos types.SyncPosition, fromPos, toPos types.PaginationToken,
joinedRoomIDs []string, joinedRoomIDs []string,
res *types.Response, res *types.Response,
) (err error) { ) (err error) {
if fromPos.TypingPosition != toPos.TypingPosition { if fromPos.EDUTypingPosition != toPos.EDUTypingPosition {
err = d.addTypingDeltaToResponse( err = d.addTypingDeltaToResponse(
fromPos.TypingPosition, joinedRoomIDs, res, fromPos, joinedRoomIDs, res,
) )
} }
@ -512,7 +510,7 @@ func (d *SyncServerDatasource) addEDUDeltaToResponse(
func (d *SyncServerDatasource) IncrementalSync( func (d *SyncServerDatasource) IncrementalSync(
ctx context.Context, ctx context.Context,
device authtypes.Device, device authtypes.Device,
fromPos, toPos types.SyncPosition, fromPos, toPos types.PaginationToken,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
wantFullState bool, wantFullState bool,
) (*types.Response, error) { ) (*types.Response, error) {
@ -552,7 +550,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
) ( ) (
res *types.Response, res *types.Response,
toPos types.SyncPosition, toPos types.PaginationToken,
joinedRoomIDs []string, joinedRoomIDs []string,
err error, err error,
) { ) {
@ -577,7 +575,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
// Get the current stream position which we will base the sync response on. // Get the current stream position which we will base the sync response on.
pos, err := d.syncStreamPositionTx(ctx, txn) pos, err := d.syncStreamPositionTx(ctx, txn)
if err != nil { if err != nil {
return nil, types.SyncPosition{}, []string{}, err return nil, types.PaginationToken{}, []string{}, err
} }
res = types.NewResponse(toPos) res = types.NewResponse(toPos)
@ -614,7 +612,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
var backwardTopologyPos types.StreamPosition var backwardTopologyPos types.StreamPosition
backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID()) backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID())
if err != nil { if err != nil {
return nil, types.SyncPosition{}, []string{}, err return nil, types.PaginationToken{}, []string{}, err
} }
if backwardTopologyPos-1 <= 0 { if backwardTopologyPos-1 <= 0 {
backwardTopologyPos = types.StreamPosition(1) backwardTopologyPos = types.StreamPosition(1)
@ -628,11 +626,11 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, backwardTopologyPos, types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
).String() ).String()
if prevPDUPos := recentStreamEvents[0].StreamPosition - 1; prevPDUPos > 0 { if prevPDUPos := recentStreamEvents[0].StreamPosition - 1; prevPDUPos > 0 {
// Use the short form of batch token for prev_batch // Use the short form of batch token for prev_batch
jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) jr.Timeline.PrevBatch = strconv.FormatInt(int64(prevPDUPos), 10)
} else { } else {
// Use the short form of batch token for prev_batch // Use the short form of batch token for prev_batch
jr.Timeline.PrevBatch = "1" jr.Timeline.PrevBatch = "1"
@ -664,7 +662,7 @@ func (d *SyncServerDatasource) CompleteSync(
// Use a zero value SyncPosition for fromPos so all EDU states are added. // Use a zero value SyncPosition for fromPos so all EDU states are added.
err = d.addEDUDeltaToResponse( err = d.addEDUDeltaToResponse(
types.SyncPosition{}, toPos, joinedRoomIDs, res, types.PaginationToken{}, toPos, joinedRoomIDs, res,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -689,7 +687,7 @@ var txReadOnlySnapshot = sql.TxOptions{
// If no data is retrieved, returns an empty map // If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error // If there was an issue with the retrieval, returns an error
func (d *SyncServerDatasource) GetAccountDataInRange( func (d *SyncServerDatasource) GetAccountDataInRange(
ctx context.Context, userID string, oldPos, newPos int64, ctx context.Context, userID string, oldPos, newPos types.StreamPosition,
accountDataFilterPart *gomatrix.FilterPart, accountDataFilterPart *gomatrix.FilterPart,
) (map[string][]string, error) { ) (map[string][]string, error) {
return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart) return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart)
@ -703,7 +701,7 @@ func (d *SyncServerDatasource) GetAccountDataInRange(
// Returns an error if there was an issue with the upsert // Returns an error if there was an issue with the upsert
func (d *SyncServerDatasource) UpsertAccountData( func (d *SyncServerDatasource) UpsertAccountData(
ctx context.Context, userID, roomID, dataType string, ctx context.Context, userID, roomID, dataType string,
) (int64, error) { ) (types.StreamPosition, error) {
return d.accountData.insertAccountData(ctx, userID, roomID, dataType) return d.accountData.insertAccountData(ctx, userID, roomID, dataType)
} }
@ -712,7 +710,7 @@ func (d *SyncServerDatasource) UpsertAccountData(
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
func (d *SyncServerDatasource) AddInviteEvent( func (d *SyncServerDatasource) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.Event, ctx context.Context, inviteEvent gomatrixserverlib.Event,
) (int64, error) { ) (types.StreamPosition, error) {
return d.invites.insertInviteEvent(ctx, inviteEvent) return d.invites.insertInviteEvent(ctx, inviteEvent)
} }
@ -735,26 +733,26 @@ func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallback
// Returns the newly calculated sync position for typing notifications. // Returns the newly calculated sync position for typing notifications.
func (d *SyncServerDatasource) AddTypingUser( func (d *SyncServerDatasource) AddTypingUser(
userID, roomID string, expireTime *time.Time, userID, roomID string, expireTime *time.Time,
) int64 { ) types.StreamPosition {
return d.typingCache.AddTypingUser(userID, roomID, expireTime) return types.StreamPosition(d.typingCache.AddTypingUser(userID, roomID, expireTime))
} }
// RemoveTypingUser removes a typing user from the typing cache. // RemoveTypingUser removes a typing user from the typing cache.
// Returns the newly calculated sync position for typing notifications. // Returns the newly calculated sync position for typing notifications.
func (d *SyncServerDatasource) RemoveTypingUser( func (d *SyncServerDatasource) RemoveTypingUser(
userID, roomID string, userID, roomID string,
) int64 { ) types.StreamPosition {
return d.typingCache.RemoveUser(userID, roomID) return types.StreamPosition(d.typingCache.RemoveUser(userID, roomID))
} }
func (d *SyncServerDatasource) addInvitesToResponse( func (d *SyncServerDatasource) addInvitesToResponse(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
userID string, userID string,
fromPos, toPos int64, fromPos, toPos types.StreamPosition,
res *types.Response, res *types.Response,
) error { ) error {
invites, err := d.invites.selectInviteEventsInRange( invites, err := d.invites.selectInviteEventsInRange(
ctx, txn, userID, int64(fromPos), int64(toPos), ctx, txn, userID, fromPos, toPos,
) )
if err != nil { if err != nil {
return err return err
@ -775,7 +773,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
ctx context.Context, ctx context.Context,
device *authtypes.Device, device *authtypes.Device,
txn *sql.Tx, txn *sql.Tx,
fromPos, toPos int64, fromPos, toPos types.StreamPosition,
delta stateDelta, delta stateDelta,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
res *types.Response, res *types.Response,
@ -800,7 +798,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) recentEvents := d.StreamEventsToEvents(device, recentStreamEvents)
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
var prevPDUPos int64 var prevPDUPos types.StreamPosition
if len(recentEvents) == 0 { if len(recentEvents) == 0 {
if len(delta.stateEvents) == 0 { if len(delta.stateEvents) == 0 {
@ -837,7 +835,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, backwardTopologyPos, types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
).String() ).String()
// Use the short form of batch token for prev_batch // Use the short form of batch token for prev_batch
//jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) //jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
@ -852,7 +850,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
// no longer in the room. // no longer in the room.
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeStream, backwardTopologyPos, types.PaginationTokenTypeStream, backwardTopologyPos, 0,
).String() ).String()
// Use the short form of batch token for prev_batch // Use the short form of batch token for prev_batch
//lr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) //lr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
@ -957,7 +955,7 @@ func (d *SyncServerDatasource) fetchMissingStateEvents(
// A list of joined room IDs is also returned in case the caller needs it. // A list of joined room IDs is also returned in case the caller needs it.
func (d *SyncServerDatasource) getStateDeltas( func (d *SyncServerDatasource) getStateDeltas(
ctx context.Context, device *authtypes.Device, txn *sql.Tx, ctx context.Context, device *authtypes.Device, txn *sql.Tx,
fromPos, toPos int64, userID string, fromPos, toPos types.StreamPosition, userID string,
stateFilterPart *gomatrix.FilterPart, stateFilterPart *gomatrix.FilterPart,
) ([]stateDelta, []string, error) { ) ([]stateDelta, []string, error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
@ -1038,7 +1036,7 @@ func (d *SyncServerDatasource) getStateDeltas(
// updates for other rooms. // updates for other rooms.
func (d *SyncServerDatasource) getStateDeltasForFullStateSync( func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
ctx context.Context, device *authtypes.Device, txn *sql.Tx, ctx context.Context, device *authtypes.Device, txn *sql.Tx,
fromPos, toPos int64, userID string, fromPos, toPos types.StreamPosition, userID string,
stateFilterPart *gomatrix.FilterPart, stateFilterPart *gomatrix.FilterPart,
) ([]stateDelta, []string, error) { ) ([]stateDelta, []string, error) {
joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)

View file

@ -33,19 +33,19 @@ type Database interface {
common.PartitionStorer common.PartitionStorer
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error)
WriteEvent(context.Context, *gomatrixserverlib.Event, []gomatrixserverlib.Event, []string, []string, *api.TransactionID, bool) (int64, error) WriteEvent(context.Context, *gomatrixserverlib.Event, []gomatrixserverlib.Event, []string, []string, *api.TransactionID, bool) (types.StreamPosition, error)
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.Event, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.Event, error)
GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrix.FilterPart) (stateEvents []gomatrixserverlib.Event, err error) GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrix.FilterPart) (stateEvents []gomatrixserverlib.Event, err error)
SyncPosition(ctx context.Context) (types.SyncPosition, error) SyncPosition(ctx context.Context) (types.PaginationToken, error)
IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.SyncPosition, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.PaginationToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error)
CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error) CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error)
GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos int64, accountDataFilterPart *gomatrix.FilterPart) (map[string][]string, error) GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos types.StreamPosition, accountDataFilterPart *gomatrix.FilterPart) (map[string][]string, error)
UpsertAccountData(ctx context.Context, userID, roomID, dataType string) (int64, error) UpsertAccountData(ctx context.Context, userID, roomID, dataType string) (types.StreamPosition, error)
AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.Event) (int64, error) AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.Event) (types.StreamPosition, error)
RetireInviteEvent(ctx context.Context, inviteEventID string) error RetireInviteEvent(ctx context.Context, inviteEventID string) error
SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn)
AddTypingUser(userID, roomID string, expireTime *time.Time) int64 AddTypingUser(userID, roomID string, expireTime *time.Time) types.StreamPosition
RemoveTypingUser(userID, roomID string) int64 RemoveTypingUser(userID, roomID string) types.StreamPosition
GetEventsInRange(ctx context.Context, from, to *types.PaginationToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) GetEventsInRange(ctx context.Context, from, to *types.PaginationToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
EventPositionInTopology(ctx context.Context, eventID string) (types.StreamPosition, error) EventPositionInTopology(ctx context.Context, eventID string) (types.StreamPosition, error)
BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities []string, err error) BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities []string, err error)

View file

@ -36,7 +36,7 @@ type Notifier struct {
// Protects currPos and userStreams. // Protects currPos and userStreams.
streamLock *sync.Mutex streamLock *sync.Mutex
// The latest sync position // The latest sync position
currPos types.SyncPosition currPos types.PaginationToken
// A map of user_id => UserStream which can be used to wake a given user's /sync request. // A map of user_id => UserStream which can be used to wake a given user's /sync request.
userStreams map[string]*UserStream userStreams map[string]*UserStream
// The last time we cleaned out stale entries from the userStreams map // The last time we cleaned out stale entries from the userStreams map
@ -46,7 +46,7 @@ type Notifier struct {
// NewNotifier creates a new notifier set to the given sync position. // NewNotifier creates a new notifier set to the given sync position.
// In order for this to be of any use, the Notifier needs to be told all rooms and // In order for this to be of any use, the Notifier needs to be told all rooms and
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
func NewNotifier(pos types.SyncPosition) *Notifier { func NewNotifier(pos types.PaginationToken) *Notifier {
return &Notifier{ return &Notifier{
currPos: pos, currPos: pos,
roomIDToJoinedUsers: make(map[string]userIDSet), roomIDToJoinedUsers: make(map[string]userIDSet),
@ -68,7 +68,7 @@ func NewNotifier(pos types.SyncPosition) *Notifier {
// event type it handles, leaving other fields as 0. // event type it handles, leaving other fields as 0.
func (n *Notifier) OnNewEvent( func (n *Notifier) OnNewEvent(
ev *gomatrixserverlib.Event, roomID string, userIDs []string, ev *gomatrixserverlib.Event, roomID string, userIDs []string,
posUpdate types.SyncPosition, posUpdate types.PaginationToken,
) { ) {
// update the current position then notify relevant /sync streams. // update the current position then notify relevant /sync streams.
// This needs to be done PRIOR to waking up users as they will read this value. // This needs to be done PRIOR to waking up users as they will read this value.
@ -151,7 +151,7 @@ func (n *Notifier) Load(ctx context.Context, db storage.Database) error {
} }
// CurrentPosition returns the current sync position // CurrentPosition returns the current sync position
func (n *Notifier) CurrentPosition() types.SyncPosition { func (n *Notifier) CurrentPosition() types.PaginationToken {
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
@ -173,7 +173,7 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
} }
} }
func (n *Notifier) wakeupUsers(userIDs []string, newPos types.SyncPosition) { func (n *Notifier) wakeupUsers(userIDs []string, newPos types.PaginationToken) {
for _, userID := range userIDs { for _, userID := range userIDs {
stream := n.fetchUserStream(userID, false) stream := n.fetchUserStream(userID, false)
if stream != nil { if stream != nil {

View file

@ -32,11 +32,11 @@ var (
randomMessageEvent gomatrixserverlib.Event randomMessageEvent gomatrixserverlib.Event
aliceInviteBobEvent gomatrixserverlib.Event aliceInviteBobEvent gomatrixserverlib.Event
bobLeaveEvent gomatrixserverlib.Event bobLeaveEvent gomatrixserverlib.Event
syncPositionVeryOld types.SyncPosition syncPositionVeryOld types.PaginationToken
syncPositionBefore types.SyncPosition syncPositionBefore types.PaginationToken
syncPositionAfter types.SyncPosition syncPositionAfter types.PaginationToken
syncPositionNewEDU types.SyncPosition syncPositionNewEDU types.PaginationToken
syncPositionAfter2 types.SyncPosition syncPositionAfter2 types.PaginationToken
) )
var ( var (
@ -46,9 +46,9 @@ var (
) )
func init() { func init() {
baseSyncPos := types.SyncPosition{ baseSyncPos := types.PaginationToken{
PDUPosition: 0, PDUPosition: 0,
TypingPosition: 0, EDUTypingPosition: 0,
} }
syncPositionVeryOld = baseSyncPos syncPositionVeryOld = baseSyncPos
@ -61,7 +61,7 @@ func init() {
syncPositionAfter.PDUPosition = 12 syncPositionAfter.PDUPosition = 12
syncPositionNewEDU = syncPositionAfter syncPositionNewEDU = syncPositionAfter
syncPositionNewEDU.TypingPosition = 1 syncPositionNewEDU.EDUTypingPosition = 1
syncPositionAfter2 = baseSyncPos syncPositionAfter2 = baseSyncPos
syncPositionAfter2.PDUPosition = 13 syncPositionAfter2.PDUPosition = 13
@ -119,7 +119,7 @@ func TestImmediateNotification(t *testing.T) {
t.Fatalf("TestImmediateNotification error: %s", err) t.Fatalf("TestImmediateNotification error: %s", err)
} }
if pos != syncPositionBefore { if pos != syncPositionBefore {
t.Fatalf("TestImmediateNotification want %d, got %d", syncPositionBefore, pos) t.Fatalf("TestImmediateNotification want %v, got %v", syncPositionBefore, pos)
} }
} }
@ -138,7 +138,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
t.Errorf("TestNewEventAndJoinedToRoom error: %s", err) t.Errorf("TestNewEventAndJoinedToRoom error: %s", err)
} }
if pos != syncPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestNewEventAndJoinedToRoom want %d, got %d", syncPositionAfter, pos) t.Errorf("TestNewEventAndJoinedToRoom want %v, got %v", syncPositionAfter, pos)
} }
wg.Done() wg.Done()
}() }()
@ -166,7 +166,7 @@ func TestNewInviteEventForUser(t *testing.T) {
t.Errorf("TestNewInviteEventForUser error: %s", err) t.Errorf("TestNewInviteEventForUser error: %s", err)
} }
if pos != syncPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionAfter, pos) t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionAfter, pos)
} }
wg.Done() wg.Done()
}() }()
@ -194,7 +194,7 @@ func TestEDUWakeup(t *testing.T) {
t.Errorf("TestNewInviteEventForUser error: %s", err) t.Errorf("TestNewInviteEventForUser error: %s", err)
} }
if pos != syncPositionNewEDU { if pos != syncPositionNewEDU {
t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionNewEDU, pos) t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionNewEDU, pos)
} }
wg.Done() wg.Done()
}() }()
@ -222,7 +222,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
t.Errorf("TestMultipleRequestWakeup error: %s", err) t.Errorf("TestMultipleRequestWakeup error: %s", err)
} }
if pos != syncPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestMultipleRequestWakeup want %d, got %d", syncPositionAfter, pos) t.Errorf("TestMultipleRequestWakeup want %v, got %v", syncPositionAfter, pos)
} }
wg.Done() wg.Done()
} }
@ -262,7 +262,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err)
} }
if pos != syncPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter, pos) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter, pos)
} }
leaveWG.Done() leaveWG.Done()
}() }()
@ -281,7 +281,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err)
} }
if pos != syncPositionAfter2 { if pos != syncPositionAfter2 {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter2, pos) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter2, pos)
} }
aliceWG.Done() aliceWG.Done()
}() }()
@ -305,14 +305,14 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
} }
func waitForEvents(n *Notifier, req syncRequest) (types.SyncPosition, error) { func waitForEvents(n *Notifier, req syncRequest) (types.PaginationToken, error) {
listener := n.GetListener(req) listener := n.GetListener(req)
defer listener.Close() defer listener.Close()
select { select {
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
return types.SyncPosition{}, fmt.Errorf( return types.PaginationToken{}, fmt.Errorf(
"waitForEvents timed out waiting for %s (pos=%d)", req.device.UserID, req.since, "waitForEvents timed out waiting for %s (pos=%v)", req.device.UserID, req.since,
) )
case <-listener.GetNotifyChannel(*req.since): case <-listener.GetNotifyChannel(*req.since):
p := listener.GetSyncPosition() p := listener.GetSyncPosition()
@ -337,7 +337,7 @@ func lockedFetchUserStream(n *Notifier, userID string) *UserStream {
return n.fetchUserStream(userID, true) return n.fetchUserStream(userID, true)
} }
func newTestSyncRequest(userID string, since types.SyncPosition) syncRequest { func newTestSyncRequest(userID string, since types.PaginationToken) syncRequest {
return syncRequest{ return syncRequest{
device: authtypes.Device{UserID: userID}, device: authtypes.Device{UserID: userID},
timeout: 1 * time.Minute, timeout: 1 * time.Minute,

View file

@ -16,11 +16,9 @@ package sync
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -45,7 +43,7 @@ type syncRequest struct {
device authtypes.Device device authtypes.Device
limit int limit int
timeout time.Duration timeout time.Duration
since *types.SyncPosition // nil means that no since token was supplied since *types.PaginationToken // nil means that no since token was supplied
wantFullState bool wantFullState bool
log *log.Entry log *log.Entry
} }
@ -95,45 +93,23 @@ func getPaginationToken(since string) (*types.StreamPosition, error) {
if p.Type != types.PaginationTokenTypeStream { if p.Type != types.PaginationTokenTypeStream {
return nil, ErrNotStreamToken return nil, ErrNotStreamToken
} }
return &(p.Position), nil return &(p.PDUPosition), nil
} }
// getSyncStreamPosition tries to parse a 'since' token taken from the API to a // getSyncStreamPosition tries to parse a 'since' token taken from the API to a
// types.SyncPosition. If the string is empty then (nil, nil) is returned. // types.PaginationToken. If the string is empty then (nil, nil) is returned.
// There are two forms of tokens: The full length form containing all PDU and EDU // There are two forms of tokens: The full length form containing all PDU and EDU
// positions separated by "_", and the short form containing only the PDU // positions separated by "_", and the short form containing only the PDU
// position. Short form can be used for, e.g., `prev_batch` tokens. // position. Short form can be used for, e.g., `prev_batch` tokens.
func getSyncStreamPosition(since string) (*types.SyncPosition, error) { func getSyncStreamPosition(since string) (*types.PaginationToken, error) {
if since == "" { if since == "" {
return nil, nil return nil, nil
} }
posStrings := strings.Split(since, "_") pos, err := types.NewPaginationTokenFromString(since)
if len(posStrings) != 2 && len(posStrings) != 1 { if err != nil {
// A token can either be full length or short (PDU-only). return nil, err
return nil, errors.New("malformed batch token")
} }
positions := make([]int64, len(posStrings)) return pos, nil
for i, posString := range posStrings {
pos, err := strconv.ParseInt(posString, 10, 64)
if err != nil {
return nil, err
}
positions[i] = pos
}
if len(positions) == 2 {
// Full length token; construct SyncPosition with every entry in
// `positions`. These entries must have the same order with the fields
// in struct SyncPosition, so we disable the govet check below.
return &types.SyncPosition{ //nolint:govet
positions[0], positions[1],
}, nil
} else {
// Token with PDU position only
return &types.SyncPosition{
PDUPosition: positions[0],
}, nil
}
} }

View file

@ -130,7 +130,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
} }
} }
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.SyncPosition) (res *types.Response, err error) { func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.PaginationToken) (res *types.Response, err error) {
// TODO: handle ignored users // TODO: handle ignored users
if req.since == nil { if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit)
@ -143,7 +143,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.SyncP
} }
accountDataFilter := gomatrix.DefaultFilterPart() // TODO: use filter provided in req instead accountDataFilter := gomatrix.DefaultFilterPart() // TODO: use filter provided in req instead
res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition, &accountDataFilter) res, err = rp.appendAccountData(res, req.device.UserID, req, int64(latestPos.PDUPosition), &accountDataFilter)
return return
} }
@ -183,7 +183,11 @@ func (rp *RequestPool) appendAccountData(
} }
// Sync is not initial, get all account data since the latest sync // Sync is not initial, get all account data since the latest sync
dataTypes, err := rp.db.GetAccountDataInRange(req.ctx, userID, req.since.PDUPosition, currentPos, accountDataFilter) dataTypes, err := rp.db.GetAccountDataInRange(
req.ctx, userID,
types.StreamPosition(req.since.PDUPosition), types.StreamPosition(currentPos),
accountDataFilter,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -35,7 +35,7 @@ type UserStream struct {
// Closed when there is an update. // Closed when there is an update.
signalChannel chan struct{} signalChannel chan struct{}
// The last sync position that there may have been an update for the user // The last sync position that there may have been an update for the user
pos types.SyncPosition pos types.PaginationToken
// The last time when we had some listeners waiting // The last time when we had some listeners waiting
timeOfLastChannel time.Time timeOfLastChannel time.Time
// The number of listeners waiting // The number of listeners waiting
@ -51,7 +51,7 @@ type UserStreamListener struct {
} }
// NewUserStream creates a new user stream // NewUserStream creates a new user stream
func NewUserStream(userID string, currPos types.SyncPosition) *UserStream { func NewUserStream(userID string, currPos types.PaginationToken) *UserStream {
return &UserStream{ return &UserStream{
UserID: userID, UserID: userID,
timeOfLastChannel: time.Now(), timeOfLastChannel: time.Now(),
@ -85,7 +85,7 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener {
} }
// Broadcast a new sync position for this user. // Broadcast a new sync position for this user.
func (s *UserStream) Broadcast(pos types.SyncPosition) { func (s *UserStream) Broadcast(pos types.PaginationToken) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
@ -120,7 +120,7 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time {
// GetStreamPosition returns last sync position which the UserStream was // GetStreamPosition returns last sync position which the UserStream was
// notified about // notified about
func (s *UserStreamListener) GetSyncPosition() types.SyncPosition { func (s *UserStreamListener) GetSyncPosition() types.PaginationToken {
s.userStream.lock.Lock() s.userStream.lock.Lock()
defer s.userStream.lock.Unlock() defer s.userStream.lock.Unlock()
@ -132,7 +132,7 @@ func (s *UserStreamListener) GetSyncPosition() types.SyncPosition {
// sincePos specifies from which point we want to be notified about. If there // sincePos specifies from which point we want to be notified about. If there
// has already been an update after sincePos we'll return a closed channel // has already been an update after sincePos we'll return a closed channel
// immediately. // immediately.
func (s *UserStreamListener) GetNotifyChannel(sincePos types.SyncPosition) <-chan struct{} { func (s *UserStreamListener) GetNotifyChannel(sincePos types.PaginationToken) <-chan struct{} {
s.userStream.lock.Lock() s.userStream.lock.Lock()
defer s.userStream.lock.Unlock() defer s.userStream.lock.Unlock()

View file

@ -18,6 +18,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv" "strconv"
"strings"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -33,48 +34,14 @@ var (
// StreamPosition represents the offset in the sync stream a client is at. // StreamPosition represents the offset in the sync stream a client is at.
type StreamPosition int64 type StreamPosition int64
// SyncPosition contains the PDU and EDU stream sync positions for a client.
type SyncPosition struct {
// PDUPosition is the stream position for PDUs the client is at.
PDUPosition int64
// TypingPosition is the client's position for typing notifications.
TypingPosition int64
}
// Same as gomatrixserverlib.Event but also has the PDU stream position for this event. // Same as gomatrixserverlib.Event but also has the PDU stream position for this event.
type StreamEvent struct { type StreamEvent struct {
gomatrixserverlib.Event gomatrixserverlib.Event
StreamPosition int64 StreamPosition StreamPosition
TransactionID *api.TransactionID TransactionID *api.TransactionID
ExcludeFromSync bool ExcludeFromSync bool
} }
// String implements the Stringer interface.
func (sp SyncPosition) String() string {
return strconv.FormatInt(sp.PDUPosition, 10) + "_" +
strconv.FormatInt(sp.TypingPosition, 10)
}
// IsAfter returns whether one SyncPosition refers to states newer than another SyncPosition.
func (sp SyncPosition) IsAfter(other SyncPosition) bool {
return sp.PDUPosition > other.PDUPosition ||
sp.TypingPosition > other.TypingPosition
}
// WithUpdates returns a copy of the SyncPosition with updates applied from another SyncPosition.
// If the latter SyncPosition contains a field that is not 0, it is considered an update,
// and its value will replace the corresponding value in the SyncPosition on which WithUpdates is called.
func (sp SyncPosition) WithUpdates(other SyncPosition) SyncPosition {
ret := sp
if other.PDUPosition != 0 {
ret.PDUPosition = other.PDUPosition
}
if other.TypingPosition != 0 {
ret.TypingPosition = other.TypingPosition
}
return ret
}
// PaginationTokenType represents the type of a pagination token. // PaginationTokenType represents the type of a pagination token.
// It can be either "s" (representing a position in the whole stream of events) // It can be either "s" (representing a position in the whole stream of events)
// or "t" (representing a position in a room's topology/depth). // or "t" (representing a position in a room's topology/depth).
@ -91,8 +58,10 @@ const (
// PaginationToken represents a pagination token, used for interactions with // PaginationToken represents a pagination token, used for interactions with
// /sync or /messages, for example. // /sync or /messages, for example.
type PaginationToken struct { type PaginationToken struct {
Position StreamPosition //Position StreamPosition
Type PaginationTokenType Type PaginationTokenType
PDUPosition StreamPosition
EDUTypingPosition StreamPosition
} }
// NewPaginationTokenFromString takes a string of the form "xyyyy..." where "x" // NewPaginationTokenFromString takes a string of the form "xyyyy..." where "x"
@ -104,17 +73,32 @@ type PaginationToken struct {
func NewPaginationTokenFromString(s string) (p *PaginationToken, err error) { func NewPaginationTokenFromString(s string) (p *PaginationToken, err error) {
p = new(PaginationToken) p = new(PaginationToken)
// Parse the token (aka position).
position, err := strconv.ParseInt(s[1:], 10, 64)
if err != nil {
return
}
p.Position = StreamPosition(position)
// Check if the type is among the known ones. // Check if the type is among the known ones.
p.Type = PaginationTokenType(s[:1]) p.Type = PaginationTokenType(s[:1])
if p.Type != PaginationTokenTypeStream && p.Type != PaginationTokenTypeTopology { if p.Type != PaginationTokenTypeStream && p.Type != PaginationTokenTypeTopology {
err = ErrInvalidPaginationTokenType err = ErrInvalidPaginationTokenType
return
}
// Parse the token (aka position).
positions := strings.Split(s[:1], "_")
// Try to get the PDU position.
if len(positions) >= 1 {
if pduPos, err := strconv.ParseInt(positions[0], 10, 64); err != nil {
return nil, err
} else {
p.PDUPosition = StreamPosition(pduPos)
}
}
// Try to get the typing position.
if len(positions) >= 2 {
if typPos, err := strconv.ParseInt(positions[1], 10, 64); err != nil {
return nil, err
} else {
p.EDUTypingPosition = StreamPosition(typPos)
}
} }
return return
@ -123,18 +107,39 @@ func NewPaginationTokenFromString(s string) (p *PaginationToken, err error) {
// NewPaginationTokenFromTypeAndPosition takes a PaginationTokenType and a // NewPaginationTokenFromTypeAndPosition takes a PaginationTokenType and a
// StreamPosition and returns an instance of PaginationToken. // StreamPosition and returns an instance of PaginationToken.
func NewPaginationTokenFromTypeAndPosition( func NewPaginationTokenFromTypeAndPosition(
t PaginationTokenType, pos StreamPosition, t PaginationTokenType, pdupos StreamPosition, typpos StreamPosition,
) (p *PaginationToken) { ) (p *PaginationToken) {
return &PaginationToken{ return &PaginationToken{
Type: t, Type: t,
Position: pos, PDUPosition: pdupos,
EDUTypingPosition: typpos,
} }
} }
// String translates a PaginationToken to a string of the "xyyyy..." (see // String translates a PaginationToken to a string of the "xyyyy..." (see
// NewPaginationToken to know what it represents). // NewPaginationToken to know what it represents).
func (p *PaginationToken) String() string { func (p *PaginationToken) String() string {
return fmt.Sprintf("%s%d", p.Type, p.Position) return fmt.Sprintf("%s%d_%d", p.Type, p.PDUPosition, p.EDUTypingPosition)
}
// WithUpdates returns a copy of the SyncPosition with updates applied from another SyncPosition.
// If the latter SyncPosition contains a field that is not 0, it is considered an update,
// and its value will replace the corresponding value in the SyncPosition on which WithUpdates is called.
func (sp *PaginationToken) WithUpdates(other PaginationToken) PaginationToken {
ret := *sp
if other.PDUPosition != 0 {
ret.PDUPosition = other.PDUPosition
}
if other.EDUTypingPosition != 0 {
ret.EDUTypingPosition = other.EDUTypingPosition
}
return ret
}
// IsAfter returns whether one SyncPosition refers to states newer than another SyncPosition.
func (sp *PaginationToken) IsAfter(other PaginationToken) bool {
return sp.PDUPosition > other.PDUPosition ||
sp.EDUTypingPosition > other.EDUTypingPosition
} }
// PrevEventRef represents a reference to a previous event in a state event upgrade // PrevEventRef represents a reference to a previous event in a state event upgrade
@ -161,7 +166,7 @@ type Response struct {
} }
// NewResponse creates an empty response with initialised maps. // NewResponse creates an empty response with initialised maps.
func NewResponse(pos SyncPosition) *Response { func NewResponse(pos PaginationToken) *Response {
res := Response{ res := Response{
NextBatch: pos.String(), NextBatch: pos.String(),
} }
@ -180,7 +185,11 @@ func NewResponse(pos SyncPosition) *Response {
// Fill next_batch with a pagination token. Since this is a response to a sync request, we can assume // Fill next_batch with a pagination token. Since this is a response to a sync request, we can assume
// we'll always return a stream token. // we'll always return a stream token.
//res.NextBatch = NewPaginationTokenFromTypeAndPosition(PaginationTokenTypeStream, pos).String() res.NextBatch = NewPaginationTokenFromTypeAndPosition(
PaginationTokenTypeStream,
StreamPosition(pos.PDUPosition),
StreamPosition(pos.EDUTypingPosition),
).String()
return &res return &res
} }