diff --git a/clientapi/routing/sendtodevice.go b/clientapi/routing/sendtodevice.go index abcf9e162..6657fd103 100644 --- a/clientapi/routing/sendtodevice.go +++ b/clientapi/routing/sendtodevice.go @@ -43,7 +43,7 @@ func SendToDevice( var httpReq struct { Messages map[string]map[string]json.RawMessage `json:"messages"` } - resErr := httputil.UnmarshalJSONRequest(req, &req) + resErr := httputil.UnmarshalJSONRequest(req, &httpReq) if resErr != nil { return *resErr } diff --git a/eduserver/input/input.go b/eduserver/input/input.go index a749f3aa5..ed9ef18f0 100644 --- a/eduserver/input/input.go +++ b/eduserver/input/input.go @@ -117,7 +117,7 @@ func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) e "user_id": ise.UserID, "device_id": ise.DeviceID, "event_type": ise.EventType, - }).Error("sendToDevice") + }).Info("handling send-to-device message") eventJSON, err := json.Marshal(ote) if err != nil { diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go index bb9c20b58..5e8410949 100644 --- a/syncapi/storage/postgres/send_to_device_table.go +++ b/syncapi/storage/postgres/send_to_device_table.go @@ -98,14 +98,14 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func (s *sendToDeviceStatements) InsertSendToDeviceMessage( ctx context.Context, txn *sql.Tx, userID, deviceID, eventType, content string, ) (err error) { - _, err = txn.Stmt(s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, eventType, content) + _, err = internal.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, eventType, content) return } func (s *sendToDeviceStatements) SelectSendToDeviceMessages( - ctx context.Context, userID, deviceID string, + ctx context.Context, txn *sql.Tx, userID, deviceID string, ) (events []types.SendToDeviceEvent, err error) { - rows, err := s.selectSendToDeviceMessagesStmt.QueryContext(ctx, userID, deviceID) + rows, err := internal.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) if err != nil { return } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 023fe8bf9..496d62404 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -1057,7 +1057,7 @@ func (d *Database) SendToDeviceUpdatesForSync( ) (events []types.SendToDeviceEvent, err error) { err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { // First of all, get our send-to-device updates for this user. - events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, userID, deviceID) + events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, txn, userID, deviceID) if err != nil { return fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) } diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index db02f40c2..7cfc45b51 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -88,14 +88,14 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func (s *sendToDeviceStatements) InsertSendToDeviceMessage( ctx context.Context, txn *sql.Tx, userID, deviceID, eventType, content string, ) (err error) { - _, err = txn.Stmt(s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, eventType, content) + _, err = internal.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, eventType, content) return } func (s *sendToDeviceStatements) SelectSendToDeviceMessages( - ctx context.Context, userID, deviceID string, + ctx context.Context, txn *sql.Tx, userID, deviceID string, ) (events []types.SendToDeviceEvent, err error) { - rows, err := s.selectSendToDeviceMessagesStmt.QueryContext(ctx, userID, deviceID) + rows, err := internal.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) if err != nil { return } diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index bb8554f43..1d35e18ca 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -512,6 +512,17 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) { } } +func TestSendToDeviceBehaviour(t *testing.T) { + //t.Parallel() + db := MustCreateDatabase(t) + + initial, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0)) + if err != nil { + t.Fatal(err) + } + fmt.Println("Initial:", initial) +} + func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) { if len(gots) != len(wants) { t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants)) diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 43527b25a..9bdbb4d0a 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -97,7 +97,7 @@ type BackwardsExtremities interface { type SendToDevice interface { InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, eventType, content string) (err error) - SelectSendToDeviceMessages(ctx context.Context, userID, deviceID string) (events []types.SendToDeviceEvent, err error) + SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (events []types.SendToDeviceEvent, err error) UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error) DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error) }