Break out the retrieval from the update/delete behaviour

This commit is contained in:
Neil Alexander 2020-05-29 15:41:04 +01:00
parent 3ca5d8e21b
commit ded130f548
4 changed files with 76 additions and 61 deletions

View file

@ -107,7 +107,9 @@ type Database interface {
// SendToDeviceUpdatesForSync returns a list of send-to-device updates, after having completed
// updates and deletions for previous events. The sync token should be supplied to this function so
// that we can clean up old events properly.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) ([]types.SendToDeviceEvent, error)
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error)
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the given sync.
CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
}

View file

@ -1063,19 +1063,20 @@ func (d *Database) SendToDeviceUpdatesForSync(
ctx context.Context,
userID, deviceID string,
token types.StreamingToken,
) (toReturn []types.SendToDeviceEvent, err error) {
) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) {
// First of all, get our send-to-device updates for this user.
events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID)
if err != nil {
return nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
return nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
}
// If there's nothing to do then stop here.
if len(events) == 0 {
return nil, nil
return nil, nil, nil, nil
}
// Work out whether we need to update any of the database entries.
toReturn := []types.SendToDeviceEvent{}
toUpdate := []types.SendToDeviceNID{}
toDelete := []types.SendToDeviceNID{}
for _, event := range events {
@ -1098,25 +1099,33 @@ func (d *Database) SendToDeviceUpdatesForSync(
}
}
return toReturn, toUpdate, toDelete, nil
}
func (d *Database) CleanSendToDeviceUpdates(
ctx context.Context,
toUpdate, toDelete []types.SendToDeviceNID,
token types.StreamingToken,
) (err error) {
if len(toUpdate) == 0 && len(toDelete) == 0 {
return nil
}
// If we need to write to the database then we'll ask the SendToDeviceWriter to
// do that for us. It'll guarantee that we don't lock the table for writes in
// more than one place.
if len(toUpdate) > 0 || len(toDelete) > 0 {
d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) {
// Delete any send-to-device messages marked for deletion.
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
err = fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
return
}
// Now update any outstanding send-to-device messages with the new sync token.
if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil {
err = fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err)
return
}
})
}
d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) {
// Delete any send-to-device messages marked for deletion.
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
err = fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
return
}
// Now update any outstanding send-to-device messages with the new sync token.
if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil {
err = fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err)
return
}
})
return
}

View file

@ -524,13 +524,17 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point there should be no messages. We haven't sent anything
// yet.
first, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0))
events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0))
if err != nil {
t.Fatal(err)
}
if len(first) != 0 {
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("first call should have no updates")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0))
if err != nil {
return
}
// Try sending a message.
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, "alice", "one", gomatrixserverlib.SendToDeviceEvent{
@ -545,42 +549,54 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should get exactly one message. We're sending the sync position
// that we were given from the update and the send-to-device update will be updated
// in the database to reflect that this was the sync position we sent the message at.
second, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
if err != nil {
t.Fatal(err)
}
if len(second) != 1 {
if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
t.Fatal("second call should have one update")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
if err != nil {
return
}
// At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position.
third, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
if err != nil {
t.Fatal(err)
}
if len(third) != 1 {
if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("third call should have one update still")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
if err != nil {
return
}
// At this point we should now have no updates, because we've progressed the sync
// position. Therefore the update from before will not be sent again.
fourth, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1))
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1))
if err != nil {
t.Fatal(err)
}
if len(fourth) != 0 {
if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
t.Fatal("fourth call should have no updates")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1))
if err != nil {
return
}
// At this point we should still have no updates, because no new updates have been
// sent.
fifth, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2))
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2))
if err != nil {
t.Fatal(err)
}
if len(fifth) != 0 {
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("fifth call should have no updates")
}
}

View file

@ -137,22 +137,35 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
res = types.NewResponse()
res, err = rp.appendSendToDeviceMessages(res, req.device.UserID, req, latestPos)
events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, latestPos)
if err != nil {
return
}
if len(res.ToDevice.Events) > 0 {
return
return nil, err
}
defer func() {
if len(updates) > 0 || len(deletions) > 0 {
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, latestPos)
if err != nil {
return
}
}
if len(events) > 0 {
for _, event := range events {
res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent)
}
if pos, perr := types.NewStreamTokenFromString(res.NextBatch); perr == nil {
pos.Positions[1]++
res.NextBatch = pos.String()
}
}
}()
// TODO: handle ignored users
if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit)
} else {
res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState)
}
if err != nil {
return
}
@ -247,31 +260,6 @@ func (rp *RequestPool) appendAccountData(
return data, nil
}
func (rp *RequestPool) appendSendToDeviceMessages(
data *types.Response, userID string, req syncRequest, currentPos types.StreamingToken,
) (*types.Response, error) {
events, err := rp.db.SendToDeviceUpdatesForSync(
context.TODO(),
userID,
req.device.ID,
currentPos,
)
if err != nil {
return nil, err
}
for _, event := range events {
data.ToDevice.Events = append(data.ToDevice.Events, event.SendToDeviceEvent)
}
if len(data.ToDevice.Events) > 0 {
currentPos.Positions[1]++
data.NextBatch = currentPos.String()
}
return data, nil
}
// shouldReturnImmediately returns whether the /sync request is an initial sync,
// or timeout=0, or full_state=true, in any of the cases the request should
// return immediately.