This commit is contained in:
Neil Alexander 2021-01-13 15:20:41 +00:00
parent 19ef429b55
commit 3a27ce69cb
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
4 changed files with 23 additions and 16 deletions

View file

@ -50,7 +50,7 @@ const insertSendToDeviceMessageSQL = `
const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 AND id <= $3
WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC
`
@ -98,9 +98,9 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, to types.StreamPosition,
ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, to)
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
if err != nil {
return
}
@ -112,6 +112,9 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
return
}
if id > lastPos {
lastPos = id
}
event := types.SendToDeviceEvent{
ID: id,
UserID: userID,
@ -121,11 +124,10 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
continue
}
events = append(events, event)
if id > lastPos {
lastPos = id
}
}
if lastPos == 0 {
lastPos = to
}
return lastPos, events, rows.Err()
}

View file

@ -903,7 +903,7 @@ func (d *Database) SendToDeviceUpdatesForSync(
from, to types.StreamPosition,
) (types.StreamPosition, []types.SendToDeviceEvent, error) {
// First of all, get our send-to-device updates for this user.
lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, to)
lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, from, to)
if err != nil {
return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
}

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/sirupsen/logrus"
)
const sendToDeviceSchema = `
@ -47,7 +48,7 @@ const insertSendToDeviceMessageSQL = `
const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 AND id <= $3
WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC
`
@ -104,9 +105,9 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, to types.StreamPosition,
ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, to)
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
if err != nil {
return
}
@ -116,22 +117,26 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
var id types.StreamPosition
var userID, deviceID, content string
if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
logrus.WithError(err).Errorf("Failed to retrieve send-to-device message")
return
}
if id > lastPos {
lastPos = id
}
event := types.SendToDeviceEvent{
ID: id,
UserID: userID,
DeviceID: deviceID,
}
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
continue
}
events = append(events, event)
if id > lastPos {
lastPos = id
}
}
if lastPos == 0 {
lastPos = to
}
return lastPos, events, rows.Err()
}

View file

@ -147,7 +147,7 @@ type BackwardsExtremities interface {
// sync response, as the client is seemingly trying to repeat the same /sync.
type SendToDevice interface {
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error)
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, to types.StreamPosition) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error)
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error)
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from types.StreamPosition) (err error)
SelectMaxSendToDeviceMessageID(ctx context.Context, txn *sql.Tx) (id int64, err error)
}