This commit is contained in:
Neil Alexander 2021-01-13 15:20:41 +00:00
parent 19ef429b55
commit 3a27ce69cb
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
4 changed files with 23 additions and 16 deletions

View file

@ -50,7 +50,7 @@ const insertSendToDeviceMessageSQL = `
const selectSendToDeviceMessagesSQL = ` const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content SELECT id, user_id, device_id, content
FROM syncapi_send_to_device FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 AND id <= $3 WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC ORDER BY id DESC
` `
@ -98,9 +98,9 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
} }
func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, to types.StreamPosition, ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) { ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, to) rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
if err != nil { if err != nil {
return return
} }
@ -112,6 +112,9 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil { if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
return return
} }
if id > lastPos {
lastPos = id
}
event := types.SendToDeviceEvent{ event := types.SendToDeviceEvent{
ID: id, ID: id,
UserID: userID, UserID: userID,
@ -121,11 +124,10 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
continue continue
} }
events = append(events, event) events = append(events, event)
if id > lastPos {
lastPos = id
}
} }
if lastPos == 0 {
lastPos = to
}
return lastPos, events, rows.Err() return lastPos, events, rows.Err()
} }

View file

@ -903,7 +903,7 @@ func (d *Database) SendToDeviceUpdatesForSync(
from, to types.StreamPosition, from, to types.StreamPosition,
) (types.StreamPosition, []types.SendToDeviceEvent, error) { ) (types.StreamPosition, []types.SendToDeviceEvent, error) {
// First of all, get our send-to-device updates for this user. // First of all, get our send-to-device updates for this user.
lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, to) lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, from, to)
if err != nil { if err != nil {
return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
} }

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/sirupsen/logrus"
) )
const sendToDeviceSchema = ` const sendToDeviceSchema = `
@ -47,7 +48,7 @@ const insertSendToDeviceMessageSQL = `
const selectSendToDeviceMessagesSQL = ` const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content SELECT id, user_id, device_id, content
FROM syncapi_send_to_device FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 AND id <= $3 WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC ORDER BY id DESC
` `
@ -104,9 +105,9 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
} }
func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, to types.StreamPosition, ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) { ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, to) rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
if err != nil { if err != nil {
return return
} }
@ -116,22 +117,26 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
var id types.StreamPosition var id types.StreamPosition
var userID, deviceID, content string var userID, deviceID, content string
if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil { if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
logrus.WithError(err).Errorf("Failed to retrieve send-to-device message")
return return
} }
if id > lastPos {
lastPos = id
}
event := types.SendToDeviceEvent{ event := types.SendToDeviceEvent{
ID: id, ID: id,
UserID: userID, UserID: userID,
DeviceID: deviceID, DeviceID: deviceID,
} }
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
continue continue
} }
events = append(events, event) events = append(events, event)
if id > lastPos {
lastPos = id
}
} }
if lastPos == 0 {
lastPos = to
}
return lastPos, events, rows.Err() return lastPos, events, rows.Err()
} }

View file

@ -147,7 +147,7 @@ type BackwardsExtremities interface {
// sync response, as the client is seemingly trying to repeat the same /sync. // sync response, as the client is seemingly trying to repeat the same /sync.
type SendToDevice interface { type SendToDevice interface {
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error) InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error)
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, to types.StreamPosition) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error)
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from types.StreamPosition) (err error) DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from types.StreamPosition) (err error)
SelectMaxSendToDeviceMessageID(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectMaxSendToDeviceMessageID(ctx context.Context, txn *sql.Tx) (id int64, err error)
} }