diff --git a/eduserver/cache/cache.go b/eduserver/cache/cache.go index 2908c3f7e..b37550892 100644 --- a/eduserver/cache/cache.go +++ b/eduserver/cache/cache.go @@ -115,9 +115,9 @@ func (t *EDUCache) AddTypingUser( func (t *EDUCache) AddSendToDeviceMessage() int64 { t.Lock() defer t.Unlock() - + r := t.latestSyncPosition t.latestSyncPosition++ - return t.latestSyncPosition - 1 + return r } // addUser with mutex lock & replace the previous timer. diff --git a/internal/sql.go b/internal/sql.go index d6a5a3086..bbfecae42 100644 --- a/internal/sql.go +++ b/internal/sql.go @@ -19,6 +19,8 @@ import ( "fmt" "runtime" "time" + + "go.uber.org/atomic" ) // A Transaction is something that can be committed or rolledback. @@ -107,3 +109,44 @@ type DbProperties interface { MaxOpenConns() int ConnMaxLifetime() time.Duration } + +type TransactionWriter struct { + running atomic.Bool + todo chan transactionWriterTask +} + +type transactionWriterTask struct { + db *sql.DB + f func(txn *sql.Tx) + wait chan struct{} +} + +func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx)) { + if w.todo == nil { + w.todo = make(chan transactionWriterTask) + } + if !w.running.Load() { + go w.run() + } + task := transactionWriterTask{ + db: db, + f: f, + wait: make(chan struct{}), + } + w.todo <- task + <-task.wait +} + +func (w *TransactionWriter) run() { + if !w.running.CAS(false, true) { + return + } + defer w.running.Store(false) + for task := range w.todo { + _ = WithTransaction(task.db, func(txn *sql.Tx) error { + task.f(txn) + return nil + }) + close(task.wait) + } +} diff --git a/syncapi/consumers/eduserver_sendtodevice.go b/syncapi/consumers/eduserver_sendtodevice.go index 356d2c4b4..daaaf06c8 100644 --- a/syncapi/consumers/eduserver_sendtodevice.go +++ b/syncapi/consumers/eduserver_sendtodevice.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) @@ -33,6 +34,7 @@ import ( type OutputSendToDeviceEventConsumer struct { sendToDeviceConsumer *internal.ContinualConsumer db storage.Database + serverName gomatrixserverlib.ServerName // our server name notifier *sync.Notifier } @@ -54,6 +56,7 @@ func NewOutputSendToDeviceEventConsumer( s := &OutputSendToDeviceEventConsumer{ sendToDeviceConsumer: &consumer, db: store, + serverName: cfg.Matrix.ServerName, notifier: n, } @@ -75,6 +78,14 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage) return err } + _, domain, err := gomatrixserverlib.SplitID('@', output.UserID) + if err != nil { + return err + } + if domain != s.serverName { + return nil + } + util.GetLogger(context.TODO()).WithFields(log.Fields{ "sender": output.Sender, "user_id": output.UserID, diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 93471ea2d..e1833501b 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -28,6 +28,7 @@ type Database struct { CurrentRoomState tables.CurrentRoomState BackwardExtremities tables.BackwardsExtremities SendToDevice tables.SendToDevice + SendToDeviceWriter internal.TransactionWriter EDUCache *cache.EDUCache } @@ -1045,9 +1046,13 @@ func (d *Database) StoreNewSendForDeviceMessage( if err != nil { return 0, err } - err = d.AddSendToDeviceEvent( - ctx, nil, userID, deviceID, string(j), - ) + // Delegate the database write task to the SendToDeviceWriter. It'll guarantee + // that we don't lock the table for writes in more than one place. + d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) { + err = d.AddSendToDeviceEvent( + ctx, txn, userID, deviceID, string(j), + ) + }) if err != nil { return 0, err } @@ -1059,42 +1064,48 @@ func (d *Database) SendToDeviceUpdatesForSync( userID, deviceID string, token types.StreamingToken, ) (events []types.SendToDeviceEvent, err error) { - err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { - // First of all, get our send-to-device updates for this user. - events, err = d.SendToDevice.SelectSendToDeviceMessages(ctx, txn, userID, deviceID) - if err != nil { - return fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) - } + // First of all, get our send-to-device updates for this user. + events, err = d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID) + if err != nil { + return nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) + } - // Start by cleaning up any send-to-device messages that have older sent-by-tokens. - // This means that they were sent in a previous /sync and the client has happily - // progressed onto newer sync tokens. - toUpdate := []types.SendToDeviceNID{} - toDelete := []types.SendToDeviceNID{} - for pos, event := range events { - if event.SentByToken == nil { - // Mark the event for update and keep it in our list of return events. - toUpdate = append(toUpdate, event.ID) - event.SentByToken = &token - } else if token.IsAfter(*event.SentByToken) { - // Mark the event for deletion and remove it from our list of return events. - toDelete = append(toDelete, event.ID) - events = append(events[:pos], events[pos+1:]...) + // Start by cleaning up any send-to-device messages that have older sent-by-tokens. + // This means that they were sent in a previous /sync and the client has happily + // progressed onto newer sync tokens. + toUpdate := []types.SendToDeviceNID{} + toDelete := []types.SendToDeviceNID{} + for pos, event := range events { + if event.SentByToken == nil { + // Mark the event for update and keep it in our list of return events. + toUpdate = append(toUpdate, event.ID) + event.SentByToken = &token + } else if token.IsAfter(*event.SentByToken) { + // Mark the event for deletion and remove it from our list of return events. + toDelete = append(toDelete, event.ID) + events = append(events[:pos], events[pos+1:]...) + } + } + + // If we need to write to the database then we'll ask the SendToDeviceWriter to + // do that for us. It'll guarantee that we don't lock the table for writes in + // more than one place. + if len(toUpdate) > 0 || len(toDelete) > 0 { + d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) { + // Delete any send-to-device messages marked for deletion. + if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil { + err = fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e) + return } - } - // Delete any send-to-device messages marked for deletion. - if err := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); err != nil { - return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", err) - } + // Now update any outstanding send-to-device messages with the new sync token. + if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil { + err = fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err) + return + } + }) + } - // Now update any outstanding send-to-device messages with the new sync token. - if err := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); err != nil { - return fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err) - } - - return nil - }) return } diff --git a/sytest-whitelist b/sytest-whitelist index d4e6be9a4..56ed2c183 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -289,3 +289,9 @@ Existing members see new members' join events Inbound federation can receive events Inbound federation can receive redacted events Can logout current device +Can send a message directly to a device using PUT /sendToDevice +Can recv a device message using /sync +Can send a to-device message to two users which both receive it using /sync +Can recv device messages until they are acknowledged +Device messages wake up /sync +Device messages over federation wake up /sync