diff --git a/eduserver/cache/cache.go b/eduserver/cache/cache.go index 7fc6c327e..7d8beb7be 100644 --- a/eduserver/cache/cache.go +++ b/eduserver/cache/cache.go @@ -115,8 +115,9 @@ func (t *EDUCache) AddTypingUser( func (t *EDUCache) AddSendToDeviceMessage() int64 { t.Lock() defer t.Unlock() + latestSyncPosition := t.latestSyncPosition t.latestSyncPosition++ - return t.latestSyncPosition + return latestSyncPosition } // addUser with mutex lock & replace the previous timer. diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index d40cbe847..f83aa7f97 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -16,6 +16,7 @@ package storage import ( "context" + "database/sql" "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -114,4 +115,6 @@ type Database interface { StoreNewSendForDeviceMessage(ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) // CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the given sync. CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error) + // SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent. + SendToDeviceUpdatesWaiting(ctx context.Context, txn *sql.Tx, userID, deviceID string) (bool, error) } diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go index a995a5275..71c86f640 100644 --- a/syncapi/storage/postgres/send_to_device_table.go +++ b/syncapi/storage/postgres/send_to_device_table.go @@ -49,6 +49,12 @@ const insertSendToDeviceMessageSQL = ` VALUES ($1, $2, $3) ` +const countSendToDeviceMessagesSQL = ` + SELECT COUNT(*) + FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 +` + const selectSendToDeviceMessagesSQL = ` SELECT id, user_id, device_id, content, sent_by_token FROM syncapi_send_to_device @@ -67,6 +73,7 @@ const deleteSendToDeviceMessagesSQL = ` type sendToDeviceStatements struct { insertSendToDeviceMessageStmt *sql.Stmt + countSendToDeviceMessagesStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt updateSentSendToDeviceMessagesStmt *sql.Stmt deleteSendToDeviceMessagesStmt *sql.Stmt @@ -81,6 +88,9 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { return nil, err } + if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil { + return nil, err + } if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { return nil, err } @@ -100,6 +110,16 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage( return } +func (s *sendToDeviceStatements) CountSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, +) (count int, err error) { + row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID) + if err = row.Scan(&count); err != nil { + return + } + return count, nil +} + func (s *sendToDeviceStatements) SelectSendToDeviceMessages( ctx context.Context, txn *sql.Tx, userID, deviceID string, ) (events []types.SendToDeviceEvent, err error) { diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 10967b76e..f362a6871 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -1035,6 +1035,17 @@ func (d *Database) currentStateStreamEventsForRoom( return s, nil } +func (d *Database) SendToDeviceUpdatesWaiting( + ctx context.Context, txn *sql.Tx, + userID, deviceID string, +) (bool, error) { + count, err := d.SendToDevice.CountSendToDeviceMessages(ctx, txn, userID, deviceID) + if err != nil { + return false, err + } + return count > 0, nil +} + func (d *Database) AddSendToDeviceEvent( ctx context.Context, txn *sql.Tx, userID, deviceID, content string, diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index 80f120222..c5185a31f 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -47,6 +47,12 @@ const insertSendToDeviceMessageSQL = ` VALUES ($1, $2, $3) ` +const countSendToDeviceMessagesSQL = ` + SELECT COUNT(*) + FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 +` + const selectSendToDeviceMessagesSQL = ` SELECT id, user_id, device_id, content, sent_by_token FROM syncapi_send_to_device @@ -66,6 +72,7 @@ const deleteSendToDeviceMessagesSQL = ` type sendToDeviceStatements struct { insertSendToDeviceMessageStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt + countSendToDeviceMessagesStmt *sql.Stmt } func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { @@ -74,6 +81,9 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { if err != nil { return nil, err } + if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil { + return nil, err + } if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { return nil, err } @@ -90,6 +100,16 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage( return } +func (s *sendToDeviceStatements) CountSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, +) (count int, err error) { + row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID) + if err = row.Scan(&count); err != nil { + return + } + return count, nil +} + func (s *sendToDeviceStatements) SelectSendToDeviceMessages( ctx context.Context, txn *sql.Tx, userID, deviceID string, ) (events []types.SendToDeviceEvent, err error) { diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 0ae1d4d91..f1c136c8f 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -100,4 +100,5 @@ type SendToDevice interface { SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (events []types.SendToDeviceEvent, err error) UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error) DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error) + CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error) } diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 725952ed8..0e141b05a 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -66,7 +66,16 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype currPos := rp.notifier.CurrentPosition() - if shouldReturnImmediately(syncReq) { + returnImmediately := shouldReturnImmediately(syncReq) + if !returnImmediately { + if waiting, werr := rp.db.SendToDeviceUpdatesWaiting( + context.TODO(), nil, device.UserID, device.ID, + ); werr == nil { + returnImmediately = waiting + } + } + + if returnImmediately { syncData, err = rp.currentSyncForUser(*syncReq, currPos) if err != nil { logger.WithError(err).Error("rp.currentSyncForUser failed") @@ -118,7 +127,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype // response. This ensures that we don't waste the hard work // of calculating the sync only to get timed out before we // can respond - syncData, err = rp.currentSyncForUser(*syncReq, currPos) if err != nil { logger.WithError(err).Error("rp.currentSyncForUser failed") @@ -139,7 +147,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea res = types.NewResponse() // See if we have any new tasks to do for the send-to-device messaging. - events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, latestPos) + events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, *req.since) if err != nil { return nil, err } @@ -150,7 +158,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea defer func() { if len(updates) > 0 || len(deletions) > 0 { // Handle the updates and deletions in the database. - err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, latestPos) + err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, *req.since) if err != nil { return } @@ -170,15 +178,6 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea } }() - if len(events) > 0 { - // This is a bit of a hack until we can do something better with the sync API - // than this mess. If we have pending send-to-device updates then we want to - // deliver them pretty quickly. We still want the next step to run so that the - // sync tokens are updated properly. Set a zero timeout on the next step so - // that we return immediately. - req.timeout = 0 - } - // TODO: handle ignored users if req.since == nil { res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit)