Handle tokens properly, return immediately if waiting send-to-device messages

This commit is contained in:
Neil Alexander 2020-06-01 11:19:43 +01:00
parent dddc9efe3e
commit 2e40e92ed1
7 changed files with 69 additions and 14 deletions

View file

@ -115,8 +115,9 @@ func (t *EDUCache) AddTypingUser(
func (t *EDUCache) AddSendToDeviceMessage() int64 {
t.Lock()
defer t.Unlock()
latestSyncPosition := t.latestSyncPosition
t.latestSyncPosition++
return t.latestSyncPosition
return latestSyncPosition
}
// addUser with mutex lock & replace the previous timer.

View file

@ -16,6 +16,7 @@ package storage
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -114,4 +115,6 @@ type Database interface {
StoreNewSendForDeviceMessage(ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the given sync.
CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
// SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent.
SendToDeviceUpdatesWaiting(ctx context.Context, txn *sql.Tx, userID, deviceID string) (bool, error)
}

View file

@ -49,6 +49,12 @@ const insertSendToDeviceMessageSQL = `
VALUES ($1, $2, $3)
`
const countSendToDeviceMessagesSQL = `
SELECT COUNT(*)
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
`
const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content, sent_by_token
FROM syncapi_send_to_device
@ -67,6 +73,7 @@ const deleteSendToDeviceMessagesSQL = `
type sendToDeviceStatements struct {
insertSendToDeviceMessageStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt
updateSentSendToDeviceMessagesStmt *sql.Stmt
deleteSendToDeviceMessagesStmt *sql.Stmt
@ -81,6 +88,9 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err
}
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
return nil, err
}
@ -100,6 +110,16 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
return
}
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (count int, err error) {
row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
if err = row.Scan(&count); err != nil {
return
}
return count, nil
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (events []types.SendToDeviceEvent, err error) {

View file

@ -1035,6 +1035,17 @@ func (d *Database) currentStateStreamEventsForRoom(
return s, nil
}
func (d *Database) SendToDeviceUpdatesWaiting(
ctx context.Context, txn *sql.Tx,
userID, deviceID string,
) (bool, error) {
count, err := d.SendToDevice.CountSendToDeviceMessages(ctx, txn, userID, deviceID)
if err != nil {
return false, err
}
return count > 0, nil
}
func (d *Database) AddSendToDeviceEvent(
ctx context.Context, txn *sql.Tx,
userID, deviceID, content string,

View file

@ -47,6 +47,12 @@ const insertSendToDeviceMessageSQL = `
VALUES ($1, $2, $3)
`
const countSendToDeviceMessagesSQL = `
SELECT COUNT(*)
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
`
const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content, sent_by_token
FROM syncapi_send_to_device
@ -66,6 +72,7 @@ const deleteSendToDeviceMessagesSQL = `
type sendToDeviceStatements struct {
insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt
}
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
@ -74,6 +81,9 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if err != nil {
return nil, err
}
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err
}
@ -90,6 +100,16 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
return
}
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (count int, err error) {
row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
if err = row.Scan(&count); err != nil {
return
}
return count, nil
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (events []types.SendToDeviceEvent, err error) {

View file

@ -100,4 +100,5 @@ type SendToDevice interface {
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (events []types.SendToDeviceEvent, err error)
UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error)
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error)
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
}

View file

@ -66,7 +66,16 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
currPos := rp.notifier.CurrentPosition()
if shouldReturnImmediately(syncReq) {
returnImmediately := shouldReturnImmediately(syncReq)
if !returnImmediately {
if waiting, werr := rp.db.SendToDeviceUpdatesWaiting(
context.TODO(), nil, device.UserID, device.ID,
); werr == nil {
returnImmediately = waiting
}
}
if returnImmediately {
syncData, err = rp.currentSyncForUser(*syncReq, currPos)
if err != nil {
logger.WithError(err).Error("rp.currentSyncForUser failed")
@ -118,7 +127,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
// response. This ensures that we don't waste the hard work
// of calculating the sync only to get timed out before we
// can respond
syncData, err = rp.currentSyncForUser(*syncReq, currPos)
if err != nil {
logger.WithError(err).Error("rp.currentSyncForUser failed")
@ -139,7 +147,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
res = types.NewResponse()
// See if we have any new tasks to do for the send-to-device messaging.
events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, latestPos)
events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, *req.since)
if err != nil {
return nil, err
}
@ -150,7 +158,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
defer func() {
if len(updates) > 0 || len(deletions) > 0 {
// Handle the updates and deletions in the database.
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, latestPos)
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, *req.since)
if err != nil {
return
}
@ -170,15 +178,6 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
}
}()
if len(events) > 0 {
// This is a bit of a hack until we can do something better with the sync API
// than this mess. If we have pending send-to-device updates then we want to
// deliver them pretty quickly. We still want the next step to run so that the
// sync tokens are updated properly. Set a zero timeout on the next step so
// that we return immediately.
req.timeout = 0
}
// TODO: handle ignored users
if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit)