mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-20 21:33:19 -06:00
Break out the retrieval from the update/delete behaviour
This commit is contained in:
parent
3ca5d8e21b
commit
ded130f548
|
|
@ -107,7 +107,9 @@ type Database interface {
|
||||||
// SendToDeviceUpdatesForSync returns a list of send-to-device updates, after having completed
|
// SendToDeviceUpdatesForSync returns a list of send-to-device updates, after having completed
|
||||||
// updates and deletions for previous events. The sync token should be supplied to this function so
|
// updates and deletions for previous events. The sync token should be supplied to this function so
|
||||||
// that we can clean up old events properly.
|
// that we can clean up old events properly.
|
||||||
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) ([]types.SendToDeviceEvent, error)
|
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error)
|
||||||
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
|
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
|
||||||
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
|
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
|
||||||
|
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the given sync.
|
||||||
|
CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1063,19 +1063,20 @@ func (d *Database) SendToDeviceUpdatesForSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userID, deviceID string,
|
userID, deviceID string,
|
||||||
token types.StreamingToken,
|
token types.StreamingToken,
|
||||||
) (toReturn []types.SendToDeviceEvent, err error) {
|
) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) {
|
||||||
// First of all, get our send-to-device updates for this user.
|
// First of all, get our send-to-device updates for this user.
|
||||||
events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID)
|
events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
|
return nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there's nothing to do then stop here.
|
// If there's nothing to do then stop here.
|
||||||
if len(events) == 0 {
|
if len(events) == 0 {
|
||||||
return nil, nil
|
return nil, nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Work out whether we need to update any of the database entries.
|
// Work out whether we need to update any of the database entries.
|
||||||
|
toReturn := []types.SendToDeviceEvent{}
|
||||||
toUpdate := []types.SendToDeviceNID{}
|
toUpdate := []types.SendToDeviceNID{}
|
||||||
toDelete := []types.SendToDeviceNID{}
|
toDelete := []types.SendToDeviceNID{}
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
|
|
@ -1098,25 +1099,33 @@ func (d *Database) SendToDeviceUpdatesForSync(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return toReturn, toUpdate, toDelete, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) CleanSendToDeviceUpdates(
|
||||||
|
ctx context.Context,
|
||||||
|
toUpdate, toDelete []types.SendToDeviceNID,
|
||||||
|
token types.StreamingToken,
|
||||||
|
) (err error) {
|
||||||
|
if len(toUpdate) == 0 && len(toDelete) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
// If we need to write to the database then we'll ask the SendToDeviceWriter to
|
// If we need to write to the database then we'll ask the SendToDeviceWriter to
|
||||||
// do that for us. It'll guarantee that we don't lock the table for writes in
|
// do that for us. It'll guarantee that we don't lock the table for writes in
|
||||||
// more than one place.
|
// more than one place.
|
||||||
if len(toUpdate) > 0 || len(toDelete) > 0 {
|
d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) {
|
||||||
d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) {
|
// Delete any send-to-device messages marked for deletion.
|
||||||
// Delete any send-to-device messages marked for deletion.
|
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
|
||||||
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
|
err = fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
|
||||||
err = fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
|
return
|
||||||
return
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Now update any outstanding send-to-device messages with the new sync token.
|
|
||||||
if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil {
|
|
||||||
err = fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Now update any outstanding send-to-device messages with the new sync token.
|
||||||
|
if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil {
|
||||||
|
err = fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -524,13 +524,17 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
||||||
|
|
||||||
// At this point there should be no messages. We haven't sent anything
|
// At this point there should be no messages. We haven't sent anything
|
||||||
// yet.
|
// yet.
|
||||||
first, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0))
|
events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(first) != 0 {
|
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
|
||||||
t.Fatal("first call should have no updates")
|
t.Fatal("first call should have no updates")
|
||||||
}
|
}
|
||||||
|
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Try sending a message.
|
// Try sending a message.
|
||||||
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, "alice", "one", gomatrixserverlib.SendToDeviceEvent{
|
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, "alice", "one", gomatrixserverlib.SendToDeviceEvent{
|
||||||
|
|
@ -545,42 +549,54 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
||||||
// At this point we should get exactly one message. We're sending the sync position
|
// At this point we should get exactly one message. We're sending the sync position
|
||||||
// that we were given from the update and the send-to-device update will be updated
|
// that we were given from the update and the send-to-device update will be updated
|
||||||
// in the database to reflect that this was the sync position we sent the message at.
|
// in the database to reflect that this was the sync position we sent the message at.
|
||||||
second, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
|
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(second) != 1 {
|
if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
|
||||||
t.Fatal("second call should have one update")
|
t.Fatal("second call should have one update")
|
||||||
}
|
}
|
||||||
|
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// At this point we should still have one message because we haven't progressed the
|
// At this point we should still have one message because we haven't progressed the
|
||||||
// sync position yet. This is equivalent to the client failing to /sync and retrying
|
// sync position yet. This is equivalent to the client failing to /sync and retrying
|
||||||
// with the same position.
|
// with the same position.
|
||||||
third, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
|
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(third) != 1 {
|
if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
|
||||||
t.Fatal("third call should have one update still")
|
t.Fatal("third call should have one update still")
|
||||||
}
|
}
|
||||||
|
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// At this point we should now have no updates, because we've progressed the sync
|
// At this point we should now have no updates, because we've progressed the sync
|
||||||
// position. Therefore the update from before will not be sent again.
|
// position. Therefore the update from before will not be sent again.
|
||||||
fourth, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1))
|
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(fourth) != 0 {
|
if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
|
||||||
t.Fatal("fourth call should have no updates")
|
t.Fatal("fourth call should have no updates")
|
||||||
}
|
}
|
||||||
|
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// At this point we should still have no updates, because no new updates have been
|
// At this point we should still have no updates, because no new updates have been
|
||||||
// sent.
|
// sent.
|
||||||
fifth, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2))
|
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(fifth) != 0 {
|
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
|
||||||
t.Fatal("fifth call should have no updates")
|
t.Fatal("fifth call should have no updates")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -137,22 +137,35 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
|
||||||
|
|
||||||
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
|
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
|
||||||
res = types.NewResponse()
|
res = types.NewResponse()
|
||||||
|
events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, latestPos)
|
||||||
res, err = rp.appendSendToDeviceMessages(res, req.device.UserID, req, latestPos)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
|
||||||
if len(res.ToDevice.Events) > 0 {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if len(updates) > 0 || len(deletions) > 0 {
|
||||||
|
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, latestPos)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(events) > 0 {
|
||||||
|
for _, event := range events {
|
||||||
|
res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent)
|
||||||
|
}
|
||||||
|
if pos, perr := types.NewStreamTokenFromString(res.NextBatch); perr == nil {
|
||||||
|
pos.Positions[1]++
|
||||||
|
res.NextBatch = pos.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// TODO: handle ignored users
|
// TODO: handle ignored users
|
||||||
if req.since == nil {
|
if req.since == nil {
|
||||||
res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit)
|
res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit)
|
||||||
} else {
|
} else {
|
||||||
res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState)
|
res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -247,31 +260,6 @@ func (rp *RequestPool) appendAccountData(
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *RequestPool) appendSendToDeviceMessages(
|
|
||||||
data *types.Response, userID string, req syncRequest, currentPos types.StreamingToken,
|
|
||||||
) (*types.Response, error) {
|
|
||||||
events, err := rp.db.SendToDeviceUpdatesForSync(
|
|
||||||
context.TODO(),
|
|
||||||
userID,
|
|
||||||
req.device.ID,
|
|
||||||
currentPos,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, event := range events {
|
|
||||||
data.ToDevice.Events = append(data.ToDevice.Events, event.SendToDeviceEvent)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(data.ToDevice.Events) > 0 {
|
|
||||||
currentPos.Positions[1]++
|
|
||||||
data.NextBatch = currentPos.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// shouldReturnImmediately returns whether the /sync request is an initial sync,
|
// shouldReturnImmediately returns whether the /sync request is an initial sync,
|
||||||
// or timeout=0, or full_state=true, in any of the cases the request should
|
// or timeout=0, or full_state=true, in any of the cases the request should
|
||||||
// return immediately.
|
// return immediately.
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue