Share more code

This commit is contained in:
Kegan Dougal 2020-05-14 15:12:41 +01:00
parent de9a48f76d
commit 83b0a0c681
2 changed files with 17 additions and 83 deletions

View file

@ -348,8 +348,16 @@ func (d *Database) GetEventsInTopologicalRange(
return return
} }
func (d *Database) SyncPosition(ctx context.Context) (types.StreamingToken, error) { func (d *Database) SyncPosition(ctx context.Context) (tok types.StreamingToken, err error) {
return d.syncPositionTx(ctx, nil) err = common.WithTransaction(d.DB, func(txn *sql.Tx) error {
pos, err := d.SyncPositionTx(ctx, txn)
if err != nil {
return err
}
tok = pos
return nil
})
return
} }
func (d *Database) BackwardExtremitiesForRoom( func (d *Database) BackwardExtremitiesForRoom(
@ -381,7 +389,8 @@ func (d *Database) EventPositionInTopology(
return d.Topology.SelectPositionInTopology(ctx, nil, eventID) return d.Topology.SelectPositionInTopology(ctx, nil, eventID)
} }
func (d *Database) syncPositionTx( // TODO FIXME TEMPORARY PUBLIC
func (d *Database) SyncPositionTx(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (sp types.StreamingToken, err error) { ) (sp types.StreamingToken, err error) {
@ -580,7 +589,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
}() }()
// Get the current sync position which we will base the sync response on. // Get the current sync position which we will base the sync response on.
toPos, err = d.syncPositionTx(ctx, txn) toPos, err = d.SyncPositionTx(ctx, txn)
if err != nil { if err != nil {
return return
} }

View file

@ -123,81 +123,6 @@ func (d *SyncServerDatasource) prepare() (err error) {
return nil return nil
} }
// SyncPosition returns the latest positions for syncing.
func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (tok types.StreamingToken, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
pos, err := d.syncPositionTx(ctx, txn)
if err != nil {
return err
}
tok = *pos
return nil
})
return
}
// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet.
func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (pos types.StreamPosition, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
pos, err = d.syncStreamPositionTx(ctx, txn)
return err
})
return
}
func (d *SyncServerDatasource) syncStreamPositionTx(
ctx context.Context, txn *sql.Tx,
) (types.StreamPosition, error) {
maxID, err := d.Database.OutputEvents.SelectMaxEventID(ctx, txn)
if err != nil {
return 0, err
}
maxAccountDataID, err := d.Database.AccountData.SelectMaxAccountDataID(ctx, txn)
if err != nil {
return 0, err
}
if maxAccountDataID > maxID {
maxID = maxAccountDataID
}
maxInviteID, err := d.Database.Invites.SelectMaxInviteID(ctx, txn)
if err != nil {
return 0, err
}
if maxInviteID > maxID {
maxID = maxInviteID
}
return types.StreamPosition(maxID), nil
}
func (d *SyncServerDatasource) syncPositionTx(
ctx context.Context, txn *sql.Tx,
) (*types.StreamingToken, error) {
maxEventID, err := d.Database.OutputEvents.SelectMaxEventID(ctx, txn)
if err != nil {
return nil, err
}
maxAccountDataID, err := d.Database.AccountData.SelectMaxAccountDataID(ctx, txn)
if err != nil {
return nil, err
}
if maxAccountDataID > maxEventID {
maxEventID = maxAccountDataID
}
maxInviteID, err := d.Database.Invites.SelectMaxInviteID(ctx, txn)
if err != nil {
return nil, err
}
if maxInviteID > maxEventID {
maxEventID = maxInviteID
}
sp := types.NewStreamToken(
types.StreamPosition(maxEventID),
types.StreamPosition(d.Database.EDUCache.GetLatestSyncPosition()),
)
return &sp, nil
}
// addPDUDeltaToResponse adds all PDU deltas to a sync response. // addPDUDeltaToResponse adds all PDU deltas to a sync response.
// IDs of all rooms the user joined are returned so EDU deltas can be added for them. // IDs of all rooms the user joined are returned so EDU deltas can be added for them.
func (d *SyncServerDatasource) addPDUDeltaToResponse( func (d *SyncServerDatasource) addPDUDeltaToResponse(
@ -356,7 +281,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
) ( ) (
res *types.Response, res *types.Response,
toPos *types.StreamingToken, toPos types.StreamingToken,
joinedRoomIDs []string, joinedRoomIDs []string,
err error, err error,
) { ) {
@ -377,12 +302,12 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
}() }()
// Get the current sync position which we will base the sync response on. // Get the current sync position which we will base the sync response on.
toPos, err = d.syncPositionTx(ctx, txn) toPos, err = d.Database.SyncPositionTx(ctx, txn)
if err != nil { if err != nil {
return return
} }
res = types.NewResponse(*toPos) res = types.NewResponse(toPos)
// Extract room state and recent events for all rooms the user is joined to. // Extract room state and recent events for all rooms the user is joined to.
joinedRoomIDs, err = d.Database.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) joinedRoomIDs, err = d.Database.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
@ -459,7 +384,7 @@ func (d *SyncServerDatasource) CompleteSync(
// Use a zero value SyncPosition for fromPos so all EDU states are added. // Use a zero value SyncPosition for fromPos so all EDU states are added.
err = d.addEDUDeltaToResponse( err = d.addEDUDeltaToResponse(
types.NewStreamToken(0, 0), *toPos, joinedRoomIDs, res, types.NewStreamToken(0, 0), toPos, joinedRoomIDs, res,
) )
if err != nil { if err != nil {
return nil, err return nil, err