From ded130f548f4fc4acc92d182c0ede7dc91c9856e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 29 May 2020 15:41:04 +0100 Subject: [PATCH] Break out the retrieval from the update/delete behaviour --- syncapi/storage/interface.go | 4 ++- syncapi/storage/shared/syncserver.go | 45 ++++++++++++++---------- syncapi/storage/storage_test.go | 36 +++++++++++++------ syncapi/sync/requestpool.go | 52 +++++++++++----------------- 4 files changed, 76 insertions(+), 61 deletions(-) diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 7378d3bd9..ec0f4de39 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -107,7 +107,9 @@ type Database interface { // SendToDeviceUpdatesForSync returns a list of send-to-device updates, after having completed // updates and deletions for previous events. The sync token should be supplied to this function so // that we can clean up old events properly. - SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) ([]types.SendToDeviceEvent, error) + SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) // StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device. StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) + // CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the given sync. + CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error) } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index ea5b2ad24..e222b9c83 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -1063,19 +1063,20 @@ func (d *Database) SendToDeviceUpdatesForSync( ctx context.Context, userID, deviceID string, token types.StreamingToken, -) (toReturn []types.SendToDeviceEvent, err error) { +) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) { // First of all, get our send-to-device updates for this user. events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID) if err != nil { - return nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) + return nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) } // If there's nothing to do then stop here. if len(events) == 0 { - return nil, nil + return nil, nil, nil, nil } // Work out whether we need to update any of the database entries. + toReturn := []types.SendToDeviceEvent{} toUpdate := []types.SendToDeviceNID{} toDelete := []types.SendToDeviceNID{} for _, event := range events { @@ -1098,25 +1099,33 @@ func (d *Database) SendToDeviceUpdatesForSync( } } + return toReturn, toUpdate, toDelete, nil +} + +func (d *Database) CleanSendToDeviceUpdates( + ctx context.Context, + toUpdate, toDelete []types.SendToDeviceNID, + token types.StreamingToken, +) (err error) { + if len(toUpdate) == 0 && len(toDelete) == 0 { + return nil + } // If we need to write to the database then we'll ask the SendToDeviceWriter to // do that for us. It'll guarantee that we don't lock the table for writes in // more than one place. - if len(toUpdate) > 0 || len(toDelete) > 0 { - d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) { - // Delete any send-to-device messages marked for deletion. - if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil { - err = fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e) - return - } - - // Now update any outstanding send-to-device messages with the new sync token. - if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil { - err = fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err) - return - } - }) - } + d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) { + // Delete any send-to-device messages marked for deletion. + if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil { + err = fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e) + return + } + // Now update any outstanding send-to-device messages with the new sync token. + if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil { + err = fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err) + return + } + }) return } diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 9cc487236..792ba1c9d 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -524,13 +524,17 @@ func TestSendToDeviceBehaviour(t *testing.T) { // At this point there should be no messages. We haven't sent anything // yet. - first, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0)) + events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0)) if err != nil { t.Fatal(err) } - if len(first) != 0 { + if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 { t.Fatal("first call should have no updates") } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0)) + if err != nil { + return + } // Try sending a message. streamPos, err := db.StoreNewSendForDeviceMessage(ctx, "alice", "one", gomatrixserverlib.SendToDeviceEvent{ @@ -545,42 +549,54 @@ func TestSendToDeviceBehaviour(t *testing.T) { // At this point we should get exactly one message. We're sending the sync position // that we were given from the update and the send-to-device update will be updated // in the database to reflect that this was the sync position we sent the message at. - second, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) if err != nil { t.Fatal(err) } - if len(second) != 1 { + if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 { t.Fatal("second call should have one update") } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos)) + if err != nil { + return + } // At this point we should still have one message because we haven't progressed the // sync position yet. This is equivalent to the client failing to /sync and retrying // with the same position. - third, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) if err != nil { t.Fatal(err) } - if len(third) != 1 { + if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 { t.Fatal("third call should have one update still") } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos)) + if err != nil { + return + } // At this point we should now have no updates, because we've progressed the sync // position. Therefore the update from before will not be sent again. - fourth, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1)) + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1)) if err != nil { t.Fatal(err) } - if len(fourth) != 0 { + if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 { t.Fatal("fourth call should have no updates") } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1)) + if err != nil { + return + } // At this point we should still have no updates, because no new updates have been // sent. - fifth, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2)) + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2)) if err != nil { t.Fatal(err) } - if len(fifth) != 0 { + if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 { t.Fatal("fifth call should have no updates") } } diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index bf5dcd896..4fb6927d3 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -137,22 +137,35 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) { res = types.NewResponse() - - res, err = rp.appendSendToDeviceMessages(res, req.device.UserID, req, latestPos) + events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, latestPos) if err != nil { - return - } - if len(res.ToDevice.Events) > 0 { - return + return nil, err } + defer func() { + if len(updates) > 0 || len(deletions) > 0 { + err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, latestPos) + if err != nil { + return + } + } + if len(events) > 0 { + for _, event := range events { + res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent) + } + if pos, perr := types.NewStreamTokenFromString(res.NextBatch); perr == nil { + pos.Positions[1]++ + res.NextBatch = pos.String() + } + } + }() + // TODO: handle ignored users if req.since == nil { res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit) } else { res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState) } - if err != nil { return } @@ -247,31 +260,6 @@ func (rp *RequestPool) appendAccountData( return data, nil } -func (rp *RequestPool) appendSendToDeviceMessages( - data *types.Response, userID string, req syncRequest, currentPos types.StreamingToken, -) (*types.Response, error) { - events, err := rp.db.SendToDeviceUpdatesForSync( - context.TODO(), - userID, - req.device.ID, - currentPos, - ) - if err != nil { - return nil, err - } - - for _, event := range events { - data.ToDevice.Events = append(data.ToDevice.Events, event.SendToDeviceEvent) - } - - if len(data.ToDevice.Events) > 0 { - currentPos.Positions[1]++ - data.NextBatch = currentPos.String() - } - - return data, nil -} - // shouldReturnImmediately returns whether the /sync request is an initial sync, // or timeout=0, or full_state=true, in any of the cases the request should // return immediately.