mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 00:03:09 -06:00
Ensure send-to-device uses real positions too
This commit is contained in:
parent
92eadf18b9
commit
281643567b
|
|
@ -94,10 +94,8 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage)
|
||||||
"event_type": output.Type,
|
"event_type": output.Type,
|
||||||
}).Info("sync API received send-to-device event from EDU server")
|
}).Info("sync API received send-to-device event from EDU server")
|
||||||
|
|
||||||
streamPos := s.db.AddSendToDevice()
|
streamPos, err := s.db.StoreNewSendForDeviceMessage(
|
||||||
|
context.TODO(), output.UserID, output.DeviceID, output.SendToDeviceEvent,
|
||||||
_, err = s.db.StoreNewSendForDeviceMessage(
|
|
||||||
context.TODO(), streamPos, output.UserID, output.DeviceID, output.SendToDeviceEvent,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Errorf("failed to store send-to-device message")
|
log.WithError(err).Errorf("failed to store send-to-device message")
|
||||||
|
|
|
||||||
|
|
@ -130,9 +130,9 @@ type Database interface {
|
||||||
// can be deleted altogether by CleanSendToDeviceUpdates
|
// can be deleted altogether by CleanSendToDeviceUpdates
|
||||||
// The token supplied should be the current requested sync token, e.g. from the "since"
|
// The token supplied should be the current requested sync token, e.g. from the "since"
|
||||||
// parameter.
|
// parameter.
|
||||||
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error)
|
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (pos types.StreamPosition, events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error)
|
||||||
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
|
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
|
||||||
StoreNewSendForDeviceMessage(ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
|
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
|
||||||
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the
|
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the
|
||||||
// result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows
|
// result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows
|
||||||
// SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after
|
// SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
|
||||||
const insertSendToDeviceMessageSQL = `
|
const insertSendToDeviceMessageSQL = `
|
||||||
INSERT INTO syncapi_send_to_device (user_id, device_id, content)
|
INSERT INTO syncapi_send_to_device (user_id, device_id, content)
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
|
RETURNING id
|
||||||
`
|
`
|
||||||
|
|
||||||
const countSendToDeviceMessagesSQL = `
|
const countSendToDeviceMessagesSQL = `
|
||||||
|
|
@ -107,8 +108,8 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
|
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
|
||||||
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
|
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
|
||||||
) (err error) {
|
) (pos types.StreamPosition, err error) {
|
||||||
_, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
|
err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).QueryRowContext(ctx, userID, deviceID, content).Scan(&pos)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -124,7 +125,7 @@ func (s *sendToDeviceStatements) CountSendToDeviceMessages(
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
||||||
ctx context.Context, txn *sql.Tx, userID, deviceID string,
|
ctx context.Context, txn *sql.Tx, userID, deviceID string,
|
||||||
) (events []types.SendToDeviceEvent, err error) {
|
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
|
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
@ -152,9 +153,12 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
events = append(events, event)
|
events = append(events, event)
|
||||||
|
if types.StreamPosition(id) > lastPos {
|
||||||
|
lastPos = types.StreamPosition(id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return events, rows.Err()
|
return lastPos, events, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
|
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
|
||||||
|
|
|
||||||
|
|
@ -1381,39 +1381,40 @@ func (d *Database) SendToDeviceUpdatesWaiting(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StoreNewSendForDeviceMessage(
|
func (d *Database) StoreNewSendForDeviceMessage(
|
||||||
ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
|
ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
|
||||||
) (types.StreamPosition, error) {
|
) (newPos types.StreamPosition, err error) {
|
||||||
j, err := json.Marshal(event)
|
j, err := json.Marshal(event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return streamPos, err
|
return 0, err
|
||||||
}
|
}
|
||||||
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee
|
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee
|
||||||
// that we don't lock the table for writes in more than one place.
|
// that we don't lock the table for writes in more than one place.
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.SendToDevice.InsertSendToDeviceMessage(
|
newPos, err = d.SendToDevice.InsertSendToDeviceMessage(
|
||||||
ctx, txn, userID, deviceID, string(j),
|
ctx, txn, userID, deviceID, string(j),
|
||||||
)
|
)
|
||||||
|
return err
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return streamPos, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return streamPos, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) SendToDeviceUpdatesForSync(
|
func (d *Database) SendToDeviceUpdatesForSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userID, deviceID string,
|
userID, deviceID string,
|
||||||
token types.StreamingToken,
|
token types.StreamingToken,
|
||||||
) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) {
|
) (types.StreamPosition, []types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) {
|
||||||
// First of all, get our send-to-device updates for this user.
|
// First of all, get our send-to-device updates for this user.
|
||||||
events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID)
|
lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
|
return 0, nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there's nothing to do then stop here.
|
// If there's nothing to do then stop here.
|
||||||
if len(events) == 0 {
|
if len(events) == 0 {
|
||||||
return nil, nil, nil, nil
|
return 0, nil, nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Work out whether we need to update any of the database entries.
|
// Work out whether we need to update any of the database entries.
|
||||||
|
|
@ -1440,7 +1441,7 @@ func (d *Database) SendToDeviceUpdatesForSync(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return toReturn, toUpdate, toDelete, nil
|
return lastPos, toReturn, toUpdate, toDelete, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) CleanSendToDeviceUpdates(
|
func (d *Database) CleanSendToDeviceUpdates(
|
||||||
|
|
|
||||||
|
|
@ -100,8 +100,14 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
|
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
|
||||||
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
|
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
|
||||||
) (err error) {
|
) (pos types.StreamPosition, err error) {
|
||||||
_, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
|
var result sql.Result
|
||||||
|
result, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
|
||||||
|
if p, err := result.LastInsertId(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
} else {
|
||||||
|
pos = types.StreamPosition(p)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -117,7 +123,7 @@ func (s *sendToDeviceStatements) CountSendToDeviceMessages(
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
||||||
ctx context.Context, txn *sql.Tx, userID, deviceID string,
|
ctx context.Context, txn *sql.Tx, userID, deviceID string,
|
||||||
) (events []types.SendToDeviceEvent, err error) {
|
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
|
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
@ -145,9 +151,12 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
events = append(events, event)
|
events = append(events, event)
|
||||||
|
if types.StreamPosition(id) > lastPos {
|
||||||
|
lastPos = types.StreamPosition(id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return events, rows.Err()
|
return lastPos, events, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
|
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
|
||||||
|
|
|
||||||
|
|
@ -539,7 +539,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
||||||
|
|
||||||
// At this point there should be no messages. We haven't sent anything
|
// At this point there should be no messages. We haven't sent anything
|
||||||
// yet.
|
// yet.
|
||||||
events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{})
|
_, events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -552,7 +552,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try sending a message.
|
// Try sending a message.
|
||||||
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, types.StreamPosition(0), "alice", "one", gomatrixserverlib.SendToDeviceEvent{
|
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, "alice", "one", gomatrixserverlib.SendToDeviceEvent{
|
||||||
Sender: "bob",
|
Sender: "bob",
|
||||||
Type: "m.type",
|
Type: "m.type",
|
||||||
Content: json.RawMessage("{}"),
|
Content: json.RawMessage("{}"),
|
||||||
|
|
@ -564,7 +564,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
||||||
// At this point we should get exactly one message. We're sending the sync position
|
// At this point we should get exactly one message. We're sending the sync position
|
||||||
// that we were given from the update and the send-to-device update will be updated
|
// that we were given from the update and the send-to-device update will be updated
|
||||||
// in the database to reflect that this was the sync position we sent the message at.
|
// in the database to reflect that this was the sync position we sent the message at.
|
||||||
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
|
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -579,7 +579,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
||||||
// At this point we should still have one message because we haven't progressed the
|
// At this point we should still have one message because we haven't progressed the
|
||||||
// sync position yet. This is equivalent to the client failing to /sync and retrying
|
// sync position yet. This is equivalent to the client failing to /sync and retrying
|
||||||
// with the same position.
|
// with the same position.
|
||||||
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
|
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -593,7 +593,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
||||||
|
|
||||||
// At this point we should now have no updates, because we've progressed the sync
|
// At this point we should now have no updates, because we've progressed the sync
|
||||||
// position. Therefore the update from before will not be sent again.
|
// position. Therefore the update from before will not be sent again.
|
||||||
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 1})
|
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 1})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -607,7 +607,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
||||||
|
|
||||||
// At this point we should still have no updates, because no new updates have been
|
// At this point we should still have no updates, because no new updates have been
|
||||||
// sent.
|
// sent.
|
||||||
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 2})
|
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 2})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -146,8 +146,8 @@ type BackwardsExtremities interface {
|
||||||
// sync parameter isn't later then we will keep including the updates in the
|
// sync parameter isn't later then we will keep including the updates in the
|
||||||
// sync response, as the client is seemingly trying to repeat the same /sync.
|
// sync response, as the client is seemingly trying to repeat the same /sync.
|
||||||
type SendToDevice interface {
|
type SendToDevice interface {
|
||||||
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (err error)
|
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error)
|
||||||
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (events []types.SendToDeviceEvent, err error)
|
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error)
|
||||||
UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error)
|
UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error)
|
||||||
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error)
|
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error)
|
||||||
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
|
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
|
||||||
|
|
|
||||||
|
|
@ -278,7 +278,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
|
||||||
res := types.NewResponse()
|
res := types.NewResponse()
|
||||||
|
|
||||||
// See if we have any new tasks to do for the send-to-device messaging.
|
// See if we have any new tasks to do for the send-to-device messaging.
|
||||||
events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, req.since)
|
lastPos, events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, req.since)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("rp.db.SendToDeviceUpdatesForSync: %w", err)
|
return nil, fmt.Errorf("rp.db.SendToDeviceUpdatesForSync: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -324,10 +324,10 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
|
||||||
// Add the updates into the sync response.
|
// Add the updates into the sync response.
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent)
|
res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent)
|
||||||
res.NextBatch.SendToDevicePosition++
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
res.NextBatch.SendToDevicePosition = lastPos
|
||||||
return res, err
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue