Simplify send-to-device messaging

This commit is contained in:
Neil Alexander 2021-01-12 12:41:41 +00:00
parent ec1b017906
commit 199c5f3f88
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
13 changed files with 205 additions and 250 deletions

View file

@ -33,6 +33,7 @@ type Database interface {
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) ([]*gomatrixserverlib.HeaderedEvent, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
@ -117,26 +118,12 @@ type Database interface {
// matches the streamevent.transactionID device then the transaction ID gets // matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event. // added to the unsigned section of the output event.
StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns three lists: // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
// - "events": a list of send-to-device events that should be included in the sync // relevant events, and it automatically truncates old events once we advance past the
// - "changes": a list of send-to-device events that should be updated in the database by // stream position of the old send-to-device messages.
// CleanSendToDeviceUpdates SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)
// - "deletions": a list of send-to-device events which have been confirmed as sent and
// can be deleted altogether by CleanSendToDeviceUpdates
// The token supplied should be the current requested sync token, e.g. from the "since"
// parameter.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (pos types.StreamPosition, events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error)
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device. // StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the
// result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows
// SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after
// starting to wait for an incremental sync with timeout).
// The token supplied should be the current requested sync token, e.g. from the "since"
// parameter.
CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
// SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent.
SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error)
// GetFilter looks up the filter associated with a given local user and filter ID. // GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists // Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database. // or if there was an error talking to the database.

View file

@ -24,6 +24,7 @@ import (
func LoadFromGoose() { func LoadFromGoose() {
goose.AddMigration(UpFixSequences, DownFixSequences) goose.AddMigration(UpFixSequences, DownFixSequences)
goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
} }
func LoadFixSequences(m *sqlutil.Migrations) { func LoadFixSequences(m *sqlutil.Migrations) {

View file

@ -0,0 +1,48 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE syncapi_send_to_device
DROP COLUMN sent_by_token;
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE syncapi_send_to_device
ADD COLUMN sent_by_token TEXT;
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -19,7 +19,6 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
@ -38,11 +37,7 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
-- The device ID to send the message to. -- The device ID to send the message to.
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
-- The event content JSON. -- The event content JSON.
content TEXT NOT NULL, content TEXT NOT NULL
-- The token that was supplied to the /sync at the time that this
-- message was included in a sync response, or NULL if we haven't
-- included it in a /sync response yet.
sent_by_token TEXT
); );
` `
@ -52,34 +47,26 @@ const insertSendToDeviceMessageSQL = `
RETURNING id RETURNING id
` `
const countSendToDeviceMessagesSQL = `
SELECT COUNT(*)
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
`
const selectSendToDeviceMessagesSQL = ` const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content, sent_by_token SELECT id, user_id, device_id, content
FROM syncapi_send_to_device FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC ORDER BY id DESC
` `
const updateSentSendToDeviceMessagesSQL = ` const deleteSendToDeviceMessagesSQL = `
UPDATE syncapi_send_to_device SET sent_by_token = $1 DELETE FROM syncapi_send_to_device
WHERE id = ANY($2) WHERE user_id = $1 AND device_id = $2 AND id < $3
` `
const deleteSendToDeviceMessagesSQL = ` const selectMaxSendToDeviceIDSQL = "" +
DELETE FROM syncapi_send_to_device WHERE id = ANY($1) "SELECT MAX(id) FROM syncapi_send_to_device"
`
type sendToDeviceStatements struct { type sendToDeviceStatements struct {
insertSendToDeviceMessageStmt *sql.Stmt insertSendToDeviceMessageStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt
updateSentSendToDeviceMessagesStmt *sql.Stmt
deleteSendToDeviceMessagesStmt *sql.Stmt deleteSendToDeviceMessagesStmt *sql.Stmt
selectMaxSendToDeviceIDStmt *sql.Stmt
} }
func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
@ -91,16 +78,13 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err return nil, err
} }
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
return nil, err return nil, err
} }
if s.updateSentSendToDeviceMessagesStmt, err = db.Prepare(updateSentSendToDeviceMessagesSQL); err != nil { if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
return nil, err return nil, err
} }
if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil { if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil {
return nil, err return nil, err
} }
return s, nil return s, nil
@ -113,20 +97,10 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
return return
} }
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (count int, err error) {
row := sqlutil.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
if err = row.Scan(&count); err != nil {
return
}
return count, nil
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) { ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
if err != nil { if err != nil {
return return
} }
@ -135,8 +109,7 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
for rows.Next() { for rows.Next() {
var id types.SendToDeviceNID var id types.SendToDeviceNID
var userID, deviceID, content string var userID, deviceID, content string
var sentByToken *string if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil {
return return
} }
event := types.SendToDeviceEvent{ event := types.SendToDeviceEvent{
@ -147,11 +120,6 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
return return
} }
if sentByToken != nil {
if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
event.SentByToken = &token
}
}
events = append(events, event) events = append(events, event)
if types.StreamPosition(id) > lastPos { if types.StreamPosition(id) > lastPos {
lastPos = types.StreamPosition(id) lastPos = types.StreamPosition(id)
@ -161,16 +129,21 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
return lastPos, events, rows.Err() return lastPos, events, rows.Err()
} }
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition,
) (err error) { ) (err error) {
_, err = sqlutil.TxStmt(txn, s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids)) _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos)
return return
} }
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx,
) (err error) { ) (id int64, err error) {
_, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids)) var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
return return
} }

View file

@ -89,6 +89,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
} }
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m) deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil { if err = m.RunDeltas(d.db, dbProperties); err != nil {
return nil, err return nil, err
} }

View file

@ -85,6 +85,14 @@ func (d *Database) MaxStreamPositionForInvites(ctx context.Context) (types.Strea
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *Database) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) {
id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, nil)
if err != nil {
return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err)
}
return types.StreamPosition(id), nil
}
func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) { func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) {
id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil) id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil)
if err != nil { if err != nil {
@ -168,30 +176,6 @@ func (d *Database) GetEventsInStreamingRange(
return events, err return events, err
} }
/*
func (d *Database) AddTypingUser(
userID, roomID string, expireTime *time.Time,
) types.StreamPosition {
return types.StreamPosition(d.EDUCache.AddTypingUser(userID, roomID, expireTime))
}
func (d *Database) RemoveTypingUser(
userID, roomID string,
) types.StreamPosition {
return types.StreamPosition(d.EDUCache.RemoveUser(userID, roomID))
}
func (d *Database) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) {
d.EDUCache.SetTimeoutCallback(fn)
}
*/
/*
func (d *Database) AddSendToDevice() types.StreamPosition {
return types.StreamPosition(d.EDUCache.AddSendToDeviceMessage())
}
*/
func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
return d.CurrentRoomState.SelectJoinedUsers(ctx) return d.CurrentRoomState.SelectJoinedUsers(ctx)
} }
@ -891,16 +875,6 @@ func (d *Database) currentStateStreamEventsForRoom(
return s, nil return s, nil
} }
func (d *Database) SendToDeviceUpdatesWaiting(
ctx context.Context, userID, deviceID string,
) (bool, error) {
count, err := d.SendToDevice.CountSendToDeviceMessages(ctx, nil, userID, deviceID)
if err != nil {
return false, err
}
return count > 0, nil
}
func (d *Database) StoreNewSendForDeviceMessage( func (d *Database) StoreNewSendForDeviceMessage(
ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent, ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
) (newPos types.StreamPosition, err error) { ) (newPos types.StreamPosition, err error) {
@ -919,77 +893,34 @@ func (d *Database) StoreNewSendForDeviceMessage(
if err != nil { if err != nil {
return 0, err return 0, err
} }
return 0, nil return
} }
func (d *Database) SendToDeviceUpdatesForSync( func (d *Database) SendToDeviceUpdatesForSync(
ctx context.Context, ctx context.Context,
userID, deviceID string, userID, deviceID string,
token types.StreamingToken, from, to types.StreamPosition,
) (types.StreamPosition, []types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) { ) (types.StreamPosition, []types.SendToDeviceEvent, error) {
// First of all, get our send-to-device updates for this user. // First of all, get our send-to-device updates for this user.
lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID) lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, from, to)
if err != nil { if err != nil {
return 0, nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) return 0, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
} }
// If there's nothing to do then stop here. // If there's nothing to do then stop here.
if len(events) == 0 { if len(events) == 0 {
return 0, nil, nil, nil, nil return 0, nil, fmt.Errorf("no send-to-device messages for user %q device %q in range %d -> %d", userID, deviceID, from, to)
} }
// Work out whether we need to update any of the database entries. // If we've advanced past this stream position for this
toReturn := []types.SendToDeviceEvent{} // user+device combo then clean up behind.
toUpdate := []types.SendToDeviceNID{} if err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
toDelete := []types.SendToDeviceNID{} return d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, userID, deviceID, from)
for _, event := range events { }); err != nil {
if event.SentByToken == nil { return 0, nil, fmt.Errorf("d.Writer.Do: %w", err)
// If the event has no sent-by token yet then we haven't attempted to send
// it. Record the current requested sync token in the database.
toUpdate = append(toUpdate, event.ID)
toReturn = append(toReturn, event)
event.SentByToken = &token
} else if token.IsAfter(*event.SentByToken) {
// The event had a sync token, therefore we've sent it before. The current
// sync token is now after the stored one so we can assume that the client
// successfully completed the previous sync (it would re-request it otherwise)
// so we can remove the entry from the database.
toDelete = append(toDelete, event.ID)
} else {
// It looks like the sync is being re-requested, maybe it timed out or
// failed. Re-send any that should have been acknowledged by now.
toReturn = append(toReturn, event)
}
} }
return lastPos, toReturn, toUpdate, toDelete, nil return lastPos, events, nil
}
func (d *Database) CleanSendToDeviceUpdates(
ctx context.Context,
toUpdate, toDelete []types.SendToDeviceNID,
token types.StreamingToken,
) (err error) {
if len(toUpdate) == 0 && len(toDelete) == 0 {
return nil
}
// If we need to write to the database then we'll ask the SendToDeviceWriter to
// do that for us. It'll guarantee that we don't lock the table for writes in
// more than one place.
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// Delete any send-to-device messages marked for deletion.
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
}
// Now update any outstanding send-to-device messages with the new sync token.
if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil {
return fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err)
}
return nil
})
return
} }
// getMembershipFromEvent returns the value of content.membership iff the event is a state event // getMembershipFromEvent returns the value of content.membership iff the event is a state event

View file

@ -24,6 +24,7 @@ import (
func LoadFromGoose() { func LoadFromGoose() {
goose.AddMigration(UpFixSequences, DownFixSequences) goose.AddMigration(UpFixSequences, DownFixSequences)
goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
} }
func LoadFixSequences(m *sqlutil.Migrations) { func LoadFixSequences(m *sqlutil.Migrations) {

View file

@ -0,0 +1,48 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE syncapi_send_to_device
DROP COLUMN sent_by_token;
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE syncapi_send_to_device
ADD COLUMN sent_by_token TEXT;
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -18,7 +18,6 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -36,11 +35,7 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
-- The device ID to send the message to. -- The device ID to send the message to.
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
-- The event content JSON. -- The event content JSON.
content TEXT NOT NULL, content TEXT NOT NULL
-- The token that was supplied to the /sync at the time that this
-- message was included in a sync response, or NULL if we haven't
-- included it in a /sync response yet.
sent_by_token TEXT
); );
` `
@ -49,33 +44,27 @@ const insertSendToDeviceMessageSQL = `
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
` `
const countSendToDeviceMessagesSQL = `
SELECT COUNT(*)
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
`
const selectSendToDeviceMessagesSQL = ` const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content, sent_by_token SELECT id, user_id, device_id, content
FROM syncapi_send_to_device FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC ORDER BY id DESC
` `
const updateSentSendToDeviceMessagesSQL = ` const deleteSendToDeviceMessagesSQL = `
UPDATE syncapi_send_to_device SET sent_by_token = $1 DELETE FROM syncapi_send_to_device
WHERE id IN ($2) WHERE user_id = $1 AND device_id = $2 AND id < $3
` `
const deleteSendToDeviceMessagesSQL = ` const selectMaxSendToDeviceIDSQL = "" +
DELETE FROM syncapi_send_to_device WHERE id IN ($1) "SELECT MAX(id) FROM syncapi_send_to_device"
`
type sendToDeviceStatements struct { type sendToDeviceStatements struct {
db *sql.DB db *sql.DB
insertSendToDeviceMessageStmt *sql.Stmt insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt deleteSendToDeviceMessagesStmt *sql.Stmt
selectMaxSendToDeviceIDStmt *sql.Stmt
} }
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
@ -86,15 +75,18 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err return nil, err
} }
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
return nil, err return nil, err
} }
if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -111,20 +103,10 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
return return
} }
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (count int, err error) {
row := sqlutil.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
if err = row.Scan(&count); err != nil {
return
}
return count, nil
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) { ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
if err != nil { if err != nil {
return return
} }
@ -133,8 +115,7 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
for rows.Next() { for rows.Next() {
var id types.SendToDeviceNID var id types.SendToDeviceNID
var userID, deviceID, content string var userID, deviceID, content string
var sentByToken *string if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil {
return return
} }
event := types.SendToDeviceEvent{ event := types.SendToDeviceEvent{
@ -145,11 +126,6 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
return return
} }
if sentByToken != nil {
if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
event.SentByToken = &token
}
}
events = append(events, event) events = append(events, event)
if types.StreamPosition(id) > lastPos { if types.StreamPosition(id) > lastPos {
lastPos = types.StreamPosition(id) lastPos = types.StreamPosition(id)
@ -159,27 +135,21 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
return lastPos, events, rows.Err() return lastPos, events, rows.Err()
} }
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition,
) (err error) { ) (err error) {
query := strings.Replace(updateSentSendToDeviceMessagesSQL, "($2)", sqlutil.QueryVariadic(1+len(nids)), 1) _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos)
params := make([]interface{}, 1+len(nids))
params[0] = token
for k, v := range nids {
params[k+1] = v
}
_, err = txn.ExecContext(ctx, query, params...)
return return
} }
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx,
) (err error) { ) (id int64, err error) {
query := strings.Replace(deleteSendToDeviceMessagesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) var nullableID sql.NullInt64
params := make([]interface{}, 1+len(nids)) stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt)
for k, v := range nids { err = stmt.QueryRowContext(ctx).Scan(&nullableID)
params[k] = v if nullableID.Valid {
id = nullableID.Int64
} }
_, err = txn.ExecContext(ctx, query, params...)
return return
} }

View file

@ -102,6 +102,7 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
} }
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m) deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil { if err = m.RunDeltas(d.db, dbProperties); err != nil {
return err return err
} }

View file

@ -147,10 +147,9 @@ type BackwardsExtremities interface {
// sync response, as the client is seemingly trying to repeat the same /sync. // sync response, as the client is seemingly trying to repeat the same /sync.
type SendToDevice interface { type SendToDevice interface {
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error) InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error)
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error)
UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error) DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition) (err error)
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error) SelectMaxSendToDeviceMessageID(ctx context.Context, txn *sql.Tx) (id int64, err error)
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
} }
type Filter interface { type Filter interface {

View file

@ -10,6 +10,16 @@ type SendToDeviceStreamProvider struct {
StreamProvider StreamProvider
} }
func (p *SendToDeviceStreamProvider) Setup() {
p.StreamProvider.Setup()
id, err := p.DB.MaxStreamPositionForSendToDeviceMessages(context.Background())
if err != nil {
panic(err)
}
p.latest = id
}
func (p *SendToDeviceStreamProvider) CompleteSync( func (p *SendToDeviceStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
req *types.SyncRequest, req *types.SyncRequest,
@ -23,29 +33,15 @@ func (p *SendToDeviceStreamProvider) IncrementalSync(
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
// See if we have any new tasks to do for the send-to-device messaging. // See if we have any new tasks to do for the send-to-device messaging.
lastPos, events, updates, deletions, err := p.DB.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, req.Since) lastPos, events, err := p.DB.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to)
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed") req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed")
return from return from
} }
// Before we return the sync response, make sure that we take action on
// any send-to-device database updates or deletions that we need to do.
// Then add the updates into the sync response.
if len(updates) > 0 || len(deletions) > 0 {
// Handle the updates and deletions in the database.
err = p.DB.CleanSendToDeviceUpdates(context.Background(), updates, deletions, req.Since)
if err != nil {
req.Log.WithError(err).Error("p.DB.CleanSendToDeviceUpdates failed")
return from
}
}
if len(events) > 0 {
// Add the updates into the sync response. // Add the updates into the sync response.
for _, event := range events { for _, event := range events {
req.Response.ToDevice.Events = append(req.Response.ToDevice.Events, event.SendToDeviceEvent) req.Response.ToDevice.Events = append(req.Response.ToDevice.Events, event.SendToDeviceEvent)
} }
}
return lastPos return lastPos
} }

View file

@ -499,7 +499,6 @@ type SendToDeviceEvent struct {
ID SendToDeviceNID ID SendToDeviceNID
UserID string UserID string
DeviceID string DeviceID string
SentByToken *StreamingToken
} }
type PeekingDevice struct { type PeekingDevice struct {