mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-20 05:13:11 -06:00
Break out the retrieval from the update/delete behaviour
This commit is contained in:
parent
3ca5d8e21b
commit
ded130f548
|
|
@ -107,7 +107,9 @@ type Database interface {
|
|||
// SendToDeviceUpdatesForSync returns a list of send-to-device updates, after having completed
|
||||
// updates and deletions for previous events. The sync token should be supplied to this function so
|
||||
// that we can clean up old events properly.
|
||||
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) ([]types.SendToDeviceEvent, error)
|
||||
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error)
|
||||
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
|
||||
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
|
||||
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the given sync.
|
||||
CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1063,19 +1063,20 @@ func (d *Database) SendToDeviceUpdatesForSync(
|
|||
ctx context.Context,
|
||||
userID, deviceID string,
|
||||
token types.StreamingToken,
|
||||
) (toReturn []types.SendToDeviceEvent, err error) {
|
||||
) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) {
|
||||
// First of all, get our send-to-device updates for this user.
|
||||
events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
|
||||
return nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
|
||||
}
|
||||
|
||||
// If there's nothing to do then stop here.
|
||||
if len(events) == 0 {
|
||||
return nil, nil
|
||||
return nil, nil, nil, nil
|
||||
}
|
||||
|
||||
// Work out whether we need to update any of the database entries.
|
||||
toReturn := []types.SendToDeviceEvent{}
|
||||
toUpdate := []types.SendToDeviceNID{}
|
||||
toDelete := []types.SendToDeviceNID{}
|
||||
for _, event := range events {
|
||||
|
|
@ -1098,25 +1099,33 @@ func (d *Database) SendToDeviceUpdatesForSync(
|
|||
}
|
||||
}
|
||||
|
||||
return toReturn, toUpdate, toDelete, nil
|
||||
}
|
||||
|
||||
func (d *Database) CleanSendToDeviceUpdates(
|
||||
ctx context.Context,
|
||||
toUpdate, toDelete []types.SendToDeviceNID,
|
||||
token types.StreamingToken,
|
||||
) (err error) {
|
||||
if len(toUpdate) == 0 && len(toDelete) == 0 {
|
||||
return nil
|
||||
}
|
||||
// If we need to write to the database then we'll ask the SendToDeviceWriter to
|
||||
// do that for us. It'll guarantee that we don't lock the table for writes in
|
||||
// more than one place.
|
||||
if len(toUpdate) > 0 || len(toDelete) > 0 {
|
||||
d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) {
|
||||
// Delete any send-to-device messages marked for deletion.
|
||||
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
|
||||
err = fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
|
||||
return
|
||||
}
|
||||
|
||||
// Now update any outstanding send-to-device messages with the new sync token.
|
||||
if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil {
|
||||
err = fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) {
|
||||
// Delete any send-to-device messages marked for deletion.
|
||||
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
|
||||
err = fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
|
||||
return
|
||||
}
|
||||
|
||||
// Now update any outstanding send-to-device messages with the new sync token.
|
||||
if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil {
|
||||
err = fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -524,13 +524,17 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
|||
|
||||
// At this point there should be no messages. We haven't sent anything
|
||||
// yet.
|
||||
first, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0))
|
||||
events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(first) != 0 {
|
||||
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
|
||||
t.Fatal("first call should have no updates")
|
||||
}
|
||||
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Try sending a message.
|
||||
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, "alice", "one", gomatrixserverlib.SendToDeviceEvent{
|
||||
|
|
@ -545,42 +549,54 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
|||
// At this point we should get exactly one message. We're sending the sync position
|
||||
// that we were given from the update and the send-to-device update will be updated
|
||||
// in the database to reflect that this was the sync position we sent the message at.
|
||||
second, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
|
||||
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(second) != 1 {
|
||||
if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
|
||||
t.Fatal("second call should have one update")
|
||||
}
|
||||
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// At this point we should still have one message because we haven't progressed the
|
||||
// sync position yet. This is equivalent to the client failing to /sync and retrying
|
||||
// with the same position.
|
||||
third, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
|
||||
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(third) != 1 {
|
||||
if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
|
||||
t.Fatal("third call should have one update still")
|
||||
}
|
||||
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// At this point we should now have no updates, because we've progressed the sync
|
||||
// position. Therefore the update from before will not be sent again.
|
||||
fourth, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1))
|
||||
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(fourth) != 0 {
|
||||
if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
|
||||
t.Fatal("fourth call should have no updates")
|
||||
}
|
||||
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// At this point we should still have no updates, because no new updates have been
|
||||
// sent.
|
||||
fifth, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2))
|
||||
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(fifth) != 0 {
|
||||
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
|
||||
t.Fatal("fifth call should have no updates")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -137,22 +137,35 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
|
|||
|
||||
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
|
||||
res = types.NewResponse()
|
||||
|
||||
res, err = rp.appendSendToDeviceMessages(res, req.device.UserID, req, latestPos)
|
||||
events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, latestPos)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(res.ToDevice.Events) > 0 {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if len(updates) > 0 || len(deletions) > 0 {
|
||||
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, latestPos)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if len(events) > 0 {
|
||||
for _, event := range events {
|
||||
res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent)
|
||||
}
|
||||
if pos, perr := types.NewStreamTokenFromString(res.NextBatch); perr == nil {
|
||||
pos.Positions[1]++
|
||||
res.NextBatch = pos.String()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// TODO: handle ignored users
|
||||
if req.since == nil {
|
||||
res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit)
|
||||
} else {
|
||||
res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -247,31 +260,6 @@ func (rp *RequestPool) appendAccountData(
|
|||
return data, nil
|
||||
}
|
||||
|
||||
func (rp *RequestPool) appendSendToDeviceMessages(
|
||||
data *types.Response, userID string, req syncRequest, currentPos types.StreamingToken,
|
||||
) (*types.Response, error) {
|
||||
events, err := rp.db.SendToDeviceUpdatesForSync(
|
||||
context.TODO(),
|
||||
userID,
|
||||
req.device.ID,
|
||||
currentPos,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, event := range events {
|
||||
data.ToDevice.Events = append(data.ToDevice.Events, event.SendToDeviceEvent)
|
||||
}
|
||||
|
||||
if len(data.ToDevice.Events) > 0 {
|
||||
currentPos.Positions[1]++
|
||||
data.NextBatch = currentPos.String()
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// shouldReturnImmediately returns whether the /sync request is an initial sync,
|
||||
// or timeout=0, or full_state=true, in any of the cases the request should
|
||||
// return immediately.
|
||||
|
|
|
|||
Loading…
Reference in a new issue