From 3a27ce69cbadba87dce51650a7f6d632a3bb77ea Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 13 Jan 2021 15:20:41 +0000 Subject: [PATCH] Fixes --- .../storage/postgres/send_to_device_table.go | 16 +++++++++------- syncapi/storage/shared/syncserver.go | 2 +- .../storage/sqlite3/send_to_device_table.go | 19 ++++++++++++------- syncapi/storage/tables/interface.go | 2 +- 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go index 3f4208906..47c1cdaed 100644 --- a/syncapi/storage/postgres/send_to_device_table.go +++ b/syncapi/storage/postgres/send_to_device_table.go @@ -50,7 +50,7 @@ const insertSendToDeviceMessageSQL = ` const selectSendToDeviceMessagesSQL = ` SELECT id, user_id, device_id, content FROM syncapi_send_to_device - WHERE user_id = $1 AND device_id = $2 AND id <= $3 + WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4 ORDER BY id DESC ` @@ -98,9 +98,9 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage( } func (s *sendToDeviceStatements) SelectSendToDeviceMessages( - ctx context.Context, txn *sql.Tx, userID, deviceID string, to types.StreamPosition, + ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition, ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, to) + rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to) if err != nil { return } @@ -112,6 +112,9 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages( if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil { return } + if id > lastPos { + lastPos = id + } event := types.SendToDeviceEvent{ ID: id, UserID: userID, @@ -121,11 +124,10 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages( continue } events = append(events, event) - if id > lastPos { - lastPos = id - } } - + if lastPos == 0 { + lastPos = to + } return lastPos, events, rows.Err() } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 2771405c1..572d60ae3 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -903,7 +903,7 @@ func (d *Database) SendToDeviceUpdatesForSync( from, to types.StreamPosition, ) (types.StreamPosition, []types.SendToDeviceEvent, error) { // First of all, get our send-to-device updates for this user. - lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, to) + lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, from, to) if err != nil { return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) } diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index 45ec026bd..0b1d5bbf2 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/sirupsen/logrus" ) const sendToDeviceSchema = ` @@ -47,7 +48,7 @@ const insertSendToDeviceMessageSQL = ` const selectSendToDeviceMessagesSQL = ` SELECT id, user_id, device_id, content FROM syncapi_send_to_device - WHERE user_id = $1 AND device_id = $2 AND id <= $3 + WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4 ORDER BY id DESC ` @@ -104,9 +105,9 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage( } func (s *sendToDeviceStatements) SelectSendToDeviceMessages( - ctx context.Context, txn *sql.Tx, userID, deviceID string, to types.StreamPosition, + ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition, ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, to) + rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to) if err != nil { return } @@ -116,22 +117,26 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages( var id types.StreamPosition var userID, deviceID, content string if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil { + logrus.WithError(err).Errorf("Failed to retrieve send-to-device message") return } + if id > lastPos { + lastPos = id + } event := types.SendToDeviceEvent{ ID: id, UserID: userID, DeviceID: deviceID, } if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { + logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message") continue } events = append(events, event) - if id > lastPos { - lastPos = id - } } - + if lastPos == 0 { + lastPos = to + } return lastPos, events, rows.Err() } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 94070655c..fca888249 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -147,7 +147,7 @@ type BackwardsExtremities interface { // sync response, as the client is seemingly trying to repeat the same /sync. type SendToDevice interface { InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error) - SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, to types.StreamPosition) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) + SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from types.StreamPosition) (err error) SelectMaxSendToDeviceMessageID(ctx context.Context, txn *sql.Tx) (id int64, err error) }