From 5389a629528ddb042636cf6a5f84caab4b911996 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 28 May 2020 17:49:56 +0100 Subject: [PATCH] Add send-to-device test, hopefully fix bugs --- .../storage/postgres/send_to_device_table.go | 1 + syncapi/storage/shared/syncserver.go | 12 ++-- .../storage/sqlite3/send_to_device_table.go | 3 +- syncapi/storage/storage_test.go | 59 ++++++++++++++++++- 4 files changed, 66 insertions(+), 9 deletions(-) diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go index 5e8410949..b9e682eb8 100644 --- a/syncapi/storage/postgres/send_to_device_table.go +++ b/syncapi/storage/postgres/send_to_device_table.go @@ -119,6 +119,7 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages( return } event := types.SendToDeviceEvent{ + ID: id, SendToDeviceEvent: gomatrixserverlib.SendToDeviceEvent{ UserID: userID, DeviceID: deviceID, diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 496d62404..8ffbe6498 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -1057,7 +1057,7 @@ func (d *Database) SendToDeviceUpdatesForSync( ) (events []types.SendToDeviceEvent, err error) { err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { // First of all, get our send-to-device updates for this user. - events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, txn, userID, deviceID) + events, err = d.SendToDevice.SelectSendToDeviceMessages(ctx, txn, userID, deviceID) if err != nil { return fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) } @@ -1068,14 +1068,14 @@ func (d *Database) SendToDeviceUpdatesForSync( toUpdate := []types.SendToDeviceNID{} toDelete := []types.SendToDeviceNID{} for pos, event := range events { - if event.SentByToken != nil && token.IsAfter(*event.SentByToken) { - // Mark the event for deletion and remove it from our list of return events. - toDelete = append(toDelete, event.ID) - events = append(events[:pos], events[pos+1:]...) - } else { + if event.SentByToken == nil { // Mark the event for update and keep it in our list of return events. toUpdate = append(toUpdate, event.ID) event.SentByToken = &token + } else if token.IsAfter(*event.SentByToken) { + // Mark the event for deletion and remove it from our list of return events. + toDelete = append(toDelete, event.ID) + events = append(events[:pos], events[pos+1:]...) } } diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index 7cfc45b51..2bcb8f62f 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -109,6 +109,7 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages( return } event := types.SendToDeviceEvent{ + ID: id, SendToDeviceEvent: gomatrixserverlib.SendToDeviceEvent{ UserID: userID, DeviceID: deviceID, @@ -130,7 +131,7 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, ) (err error) { - query := strings.Replace(updateSentSendToDeviceMessagesSQL, "($2)", internal.QueryVariadic(len(nids)), 1) + query := strings.Replace(updateSentSendToDeviceMessagesSQL, "($2)", internal.QueryVariadic(1+len(nids)), 1) params := make([]interface{}, 1+len(nids)) params[0] = token for k, v := range nids { diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 1d35e18ca..88efd8f04 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -3,6 +3,7 @@ package storage_test import ( "context" "crypto/ed25519" + "encoding/json" "fmt" "testing" "time" @@ -516,11 +517,65 @@ func TestSendToDeviceBehaviour(t *testing.T) { //t.Parallel() db := MustCreateDatabase(t) - initial, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0)) + // At this point there should be no messages. We haven't sent anything + // yet. + first, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0)) if err != nil { t.Fatal(err) } - fmt.Println("Initial:", initial) + if len(first) != 0 { + t.Fatal("first call should have no updates") + } + + // Try sending a message. + streamPos, err := db.StoreNewSendForDeviceMessage(ctx, gomatrixserverlib.SendToDeviceEvent{ + UserID: "alice", + DeviceID: "one", + EventType: "m.type", + Message: json.RawMessage("{}"), + }) + if err != nil { + t.Fatal(err) + } + + // At this point we should get exactly one message. We're sending the sync position + // that we were given from the update. + second, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) + if err != nil { + t.Fatal(err) + } + if len(second) != 1 { + t.Fatal("second call should have one update") + } + + // At this point we should still have one message because we haven't progressed the + // sync position yet. This is equivalent to the client failing to /sync and retrying + // with the same position. + third, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) + if err != nil { + t.Fatal(err) + } + if len(third) != 1 { + t.Fatal("third call should have one update still") + } + + // At this point we should now have no updates, because we've progressed the sync + // position. Therefore the update from before will be cleane + fourth, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1)) + if err != nil { + t.Fatal(err) + } + if len(fourth) != 0 { + t.Fatal("fourth call should have no updates") + } + + fifth, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2)) + if err != nil { + t.Fatal(err) + } + if len(fifth) != 0 { + t.Fatal("fifth call should have no updates") + } } func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) {