mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-20 05:13:11 -06:00
Handle tokens properly, return immediately if waiting send-to-device messages
This commit is contained in:
parent
dddc9efe3e
commit
2e40e92ed1
3
eduserver/cache/cache.go
vendored
3
eduserver/cache/cache.go
vendored
|
|
@ -115,8 +115,9 @@ func (t *EDUCache) AddTypingUser(
|
|||
func (t *EDUCache) AddSendToDeviceMessage() int64 {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
latestSyncPosition := t.latestSyncPosition
|
||||
t.latestSyncPosition++
|
||||
return t.latestSyncPosition
|
||||
return latestSyncPosition
|
||||
}
|
||||
|
||||
// addUser with mutex lock & replace the previous timer.
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ package storage
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
|
|
@ -114,4 +115,6 @@ type Database interface {
|
|||
StoreNewSendForDeviceMessage(ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
|
||||
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the given sync.
|
||||
CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
|
||||
// SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent.
|
||||
SendToDeviceUpdatesWaiting(ctx context.Context, txn *sql.Tx, userID, deviceID string) (bool, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,6 +49,12 @@ const insertSendToDeviceMessageSQL = `
|
|||
VALUES ($1, $2, $3)
|
||||
`
|
||||
|
||||
const countSendToDeviceMessagesSQL = `
|
||||
SELECT COUNT(*)
|
||||
FROM syncapi_send_to_device
|
||||
WHERE user_id = $1 AND device_id = $2
|
||||
`
|
||||
|
||||
const selectSendToDeviceMessagesSQL = `
|
||||
SELECT id, user_id, device_id, content, sent_by_token
|
||||
FROM syncapi_send_to_device
|
||||
|
|
@ -67,6 +73,7 @@ const deleteSendToDeviceMessagesSQL = `
|
|||
|
||||
type sendToDeviceStatements struct {
|
||||
insertSendToDeviceMessageStmt *sql.Stmt
|
||||
countSendToDeviceMessagesStmt *sql.Stmt
|
||||
selectSendToDeviceMessagesStmt *sql.Stmt
|
||||
updateSentSendToDeviceMessagesStmt *sql.Stmt
|
||||
deleteSendToDeviceMessagesStmt *sql.Stmt
|
||||
|
|
@ -81,6 +88,9 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
|||
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -100,6 +110,16 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID string,
|
||||
) (count int, err error) {
|
||||
row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
|
||||
if err = row.Scan(&count); err != nil {
|
||||
return
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID string,
|
||||
) (events []types.SendToDeviceEvent, err error) {
|
||||
|
|
|
|||
|
|
@ -1035,6 +1035,17 @@ func (d *Database) currentStateStreamEventsForRoom(
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (d *Database) SendToDeviceUpdatesWaiting(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
userID, deviceID string,
|
||||
) (bool, error) {
|
||||
count, err := d.SendToDevice.CountSendToDeviceMessages(ctx, txn, userID, deviceID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (d *Database) AddSendToDeviceEvent(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
userID, deviceID, content string,
|
||||
|
|
|
|||
|
|
@ -47,6 +47,12 @@ const insertSendToDeviceMessageSQL = `
|
|||
VALUES ($1, $2, $3)
|
||||
`
|
||||
|
||||
const countSendToDeviceMessagesSQL = `
|
||||
SELECT COUNT(*)
|
||||
FROM syncapi_send_to_device
|
||||
WHERE user_id = $1 AND device_id = $2
|
||||
`
|
||||
|
||||
const selectSendToDeviceMessagesSQL = `
|
||||
SELECT id, user_id, device_id, content, sent_by_token
|
||||
FROM syncapi_send_to_device
|
||||
|
|
@ -66,6 +72,7 @@ const deleteSendToDeviceMessagesSQL = `
|
|||
type sendToDeviceStatements struct {
|
||||
insertSendToDeviceMessageStmt *sql.Stmt
|
||||
selectSendToDeviceMessagesStmt *sql.Stmt
|
||||
countSendToDeviceMessagesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
||||
|
|
@ -74,6 +81,9 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -90,6 +100,16 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID string,
|
||||
) (count int, err error) {
|
||||
row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
|
||||
if err = row.Scan(&count); err != nil {
|
||||
return
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID string,
|
||||
) (events []types.SendToDeviceEvent, err error) {
|
||||
|
|
|
|||
|
|
@ -100,4 +100,5 @@ type SendToDevice interface {
|
|||
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (events []types.SendToDeviceEvent, err error)
|
||||
UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error)
|
||||
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error)
|
||||
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,7 +66,16 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
|
|||
|
||||
currPos := rp.notifier.CurrentPosition()
|
||||
|
||||
if shouldReturnImmediately(syncReq) {
|
||||
returnImmediately := shouldReturnImmediately(syncReq)
|
||||
if !returnImmediately {
|
||||
if waiting, werr := rp.db.SendToDeviceUpdatesWaiting(
|
||||
context.TODO(), nil, device.UserID, device.ID,
|
||||
); werr == nil {
|
||||
returnImmediately = waiting
|
||||
}
|
||||
}
|
||||
|
||||
if returnImmediately {
|
||||
syncData, err = rp.currentSyncForUser(*syncReq, currPos)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("rp.currentSyncForUser failed")
|
||||
|
|
@ -118,7 +127,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
|
|||
// response. This ensures that we don't waste the hard work
|
||||
// of calculating the sync only to get timed out before we
|
||||
// can respond
|
||||
|
||||
syncData, err = rp.currentSyncForUser(*syncReq, currPos)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("rp.currentSyncForUser failed")
|
||||
|
|
@ -139,7 +147,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
|
|||
res = types.NewResponse()
|
||||
|
||||
// See if we have any new tasks to do for the send-to-device messaging.
|
||||
events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, latestPos)
|
||||
events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, *req.since)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -150,7 +158,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
|
|||
defer func() {
|
||||
if len(updates) > 0 || len(deletions) > 0 {
|
||||
// Handle the updates and deletions in the database.
|
||||
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, latestPos)
|
||||
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, *req.since)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -170,15 +178,6 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
|
|||
}
|
||||
}()
|
||||
|
||||
if len(events) > 0 {
|
||||
// This is a bit of a hack until we can do something better with the sync API
|
||||
// than this mess. If we have pending send-to-device updates then we want to
|
||||
// deliver them pretty quickly. We still want the next step to run so that the
|
||||
// sync tokens are updated properly. Set a zero timeout on the next step so
|
||||
// that we return immediately.
|
||||
req.timeout = 0
|
||||
}
|
||||
|
||||
// TODO: handle ignored users
|
||||
if req.since == nil {
|
||||
res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit)
|
||||
|
|
|
|||
Loading…
Reference in a new issue