diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 0685c7352..c9647eb1b 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -19,6 +19,8 @@ import ( "net/http" "time" + "github.com/matrix-org/gomatrixserverlib" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" @@ -27,7 +29,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrix" "github.com/matrix-org/util" @@ -126,20 +127,6 @@ func SetAvatarURL( } } - res := &userapi.QueryProfileResponse{} - err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{ - UserID: userID, - }, res) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed") - return jsonerror.InternalServerError() - } - oldProfile := &authtypes.Profile{ - Localpart: localpart, - DisplayName: res.DisplayName, - AvatarURL: res.AvatarURL, - } - setRes := &userapi.PerformSetAvatarURLResponse{} if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{ Localpart: localpart, @@ -148,41 +135,17 @@ func SetAvatarURL( util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") return jsonerror.InternalServerError() } - - var roomsRes api.QueryRoomsForUserResponse - err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ - UserID: device.UserID, - WantMembership: "join", - }, &roomsRes) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") - return jsonerror.InternalServerError() - } - - newProfile := authtypes.Profile{ - Localpart: localpart, - DisplayName: oldProfile.DisplayName, - AvatarURL: r.AvatarURL, - } - - events, err := buildMembershipEvents( - req.Context(), roomsRes.RoomIDs, newProfile, userID, cfg, evTime, rsAPI, - ) - switch e := err.(type) { - case nil: - case gomatrixserverlib.BadJSONError: + // No need to build new membership events, since nothing changed + if !setRes.Changed { return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + Code: http.StatusOK, + JSON: struct{}{}, } - default: - util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed") - return jsonerror.InternalServerError() } - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() + response, err := updateProfile(req.Context(), rsAPI, device, setRes.Profile, userID, cfg, evTime) + if err != nil { + return response } return util.JSONResponse{ @@ -255,47 +218,51 @@ func SetDisplayName( } } - pRes := &userapi.QueryProfileResponse{} - err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{ - UserID: userID, - }, pRes) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed") - return jsonerror.InternalServerError() - } - oldProfile := &authtypes.Profile{ - Localpart: localpart, - DisplayName: pRes.DisplayName, - AvatarURL: pRes.AvatarURL, - } - + profileRes := &userapi.PerformUpdateDisplayNameResponse{} err = profileAPI.SetDisplayName(req.Context(), 
&userapi.PerformUpdateDisplayNameRequest{ Localpart: localpart, DisplayName: r.DisplayName, - }, &struct{}{}) + }, profileRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") return jsonerror.InternalServerError() } + // No need to build new membership events, since nothing changed + if !profileRes.Changed { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } + } + response, err := updateProfile(req.Context(), rsAPI, device, profileRes.Profile, userID, cfg, evTime) + if err != nil { + return response + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +func updateProfile( + ctx context.Context, rsAPI api.ClientRoomserverAPI, device *userapi.Device, + profile *authtypes.Profile, + userID string, cfg *config.ClientAPI, evTime time.Time, +) (util.JSONResponse, error) { var res api.QueryRoomsForUserResponse - err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ + err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ UserID: device.UserID, WantMembership: "join", }, &res) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") - return jsonerror.InternalServerError() - } - - newProfile := authtypes.Profile{ - Localpart: localpart, - DisplayName: r.DisplayName, - AvatarURL: oldProfile.AvatarURL, + util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed") + return jsonerror.InternalServerError(), err } events, err := buildMembershipEvents( - req.Context(), res.RoomIDs, newProfile, userID, cfg, evTime, rsAPI, + ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, ) switch e := err.(type) { case nil: @@ -303,21 +270,17 @@ func SetDisplayName( return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(e.Error()), - } + }, e default: - util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed") - return jsonerror.InternalServerError() + util.GetLogger(ctx).WithError(err).Error("buildMembershipEvents failed") + return jsonerror.InternalServerError(), e } - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() - } - - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, + if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { + util.GetLogger(ctx).WithError(err).Error("SendEvents failed") + return jsonerror.InternalServerError(), err } + return util.JSONResponse{}, nil } // getProfile gets the full profile of a user by querying the database or a diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index ec5ca899e..4ca8e59c5 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -178,7 +178,7 @@ func Setup( // server notifications if cfg.Matrix.ServerNotices.Enabled { logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") - serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, cfg) + serverNotificationSender, err := getSenderDevice(context.Background(), rsAPI, userAPI, cfg) if err != nil { logrus.WithError(err).Fatal("unable to get account for sending server notices") } diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index 
7729eddd8..a6a78061d 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -277,6 +277,7 @@ func (r sendServerNoticeRequest) valid() (ok bool) { // It returns a userapi.Device, which is used for building the event func getSenderDevice( ctx context.Context, + rsAPI api.ClientRoomserverAPI, userAPI userapi.ClientUserAPI, cfg *config.ClientAPI, ) (*userapi.Device, error) { @@ -291,16 +292,32 @@ func getSenderDevice( return nil, err } - // set the avatarurl for the user - res := &userapi.PerformSetAvatarURLResponse{} + // Set the avatar URL for the user + avatarRes := &userapi.PerformSetAvatarURLResponse{} if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{ Localpart: cfg.Matrix.ServerNotices.LocalPart, AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, - }, res); err != nil { + }, avatarRes); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed") return nil, err } + profile := avatarRes.Profile + + // Set the display name for the user + displayNameRes := &userapi.PerformUpdateDisplayNameResponse{} + if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{ + Localpart: cfg.Matrix.ServerNotices.LocalPart, + DisplayName: cfg.Matrix.ServerNotices.DisplayName, + }, displayNameRes); err != nil { + util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed") + return nil, err + } + + if displayNameRes.Changed { + profile.DisplayName = cfg.Matrix.ServerNotices.DisplayName + } + // Check if we got existing devices deviceRes := &userapi.QueryDevicesResponse{} err = userAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{ @@ -310,7 +327,15 @@ return nil, err } + // We've got an existing account, so return its first device if len(deviceRes.Devices) > 0 { + // If there were changes to the profile, create a new membership event + if displayNameRes.Changed || avatarRes.Changed { + _, err = updateProfile(ctx, rsAPI, &deviceRes.Devices[0], profile, accRes.Account.UserID, cfg, time.Now()) + if err != nil { + return nil, err + } + } return &deviceRes.Devices[0], nil } diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index bc73df728..c7ba43711 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -231,9 +231,9 @@ GEM jekyll-seo-tag (~> 2.1) minitest (5.15.0) multipart-post (2.1.1) - nokogiri (1.13.6-arm64-darwin) + nokogiri (1.13.9-arm64-darwin) racc (~> 1.4) - nokogiri (1.13.6-x86_64-linux) + nokogiri (1.13.9-x86_64-linux) racc (~> 1.4) octokit (4.22.0) faraday (>= 0.9) diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 4a13c9d9b..f6dace702 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -116,17 +116,14 @@ func NewInternalAPI( _ = federationDB.RemoveAllServersFromBlacklist() } - stats := &statistics.Statistics{ - DB: federationDB, - FailuresUntilBlacklist: cfg.FederationMaxRetries, - } + stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1) js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, cfg.Matrix.DisableFederation, - cfg.Matrix.ServerName, federation, rsAPI, stats, + cfg.Matrix.ServerName, federation, rsAPI, &stats, &queue.SigningInfo{ KeyID: cfg.Matrix.KeyID, PrivateKey: cfg.Matrix.PrivateKey, @@ -183,5 +180,5 @@ func NewInternalAPI( } time.AfterFunc(time.Minute, cleanExpiredEDUs) - return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, 
queues, keyRing) + return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, &stats, caches, queues, keyRing) } diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 5cb8cae1f..1b7670e9a 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -21,21 +21,22 @@ import ( "sync" "time" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "go.uber.org/atomic" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" - "go.uber.org/atomic" ) const ( maxPDUsPerTransaction = 50 - maxEDUsPerTransaction = 50 + maxEDUsPerTransaction = 100 maxPDUsInMemory = 128 maxEDUsInMemory = 128 queueIdleTimeout = time.Second * 30 @@ -64,7 +65,6 @@ type destinationQueue struct { pendingPDUs []*queuedPDU // PDUs waiting to be sent pendingEDUs []*queuedEDU // EDUs waiting to be sent pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs - interruptBackoff chan bool // interrupts backoff } // Send event adds the event to the pending queue for the destination. @@ -75,39 +75,22 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) return } - // Create a database entry that associates the given PDU NID with - // this destination queue. We'll then be able to retrieve the PDU - // later. - if err := oq.db.AssociatePDUWithDestination( - oq.process.Context(), - "", // TODO: remove this, as we don't need to persist the transaction ID - oq.destination, // the destination server name - receipt, // NIDs from federationapi_queue_json table - ); err != nil { - logrus.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination) - return + + // If there's room in memory to hold the event then add it to the + // list. + oq.pendingMutex.Lock() + if len(oq.pendingPDUs) < maxPDUsInMemory { + oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ + pdu: event, + receipt: receipt, + }) + } else { + oq.overflowed.Store(true) } - // Check if the destination is blacklisted. If it isn't then wake - // up the queue. - if !oq.statistics.Blacklisted() { - // If there's room in memory to hold the event then add it to the - // list. - oq.pendingMutex.Lock() - if len(oq.pendingPDUs) < maxPDUsInMemory { - oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ - pdu: event, - receipt: receipt, - }) - } else { - oq.overflowed.Store(true) - } - oq.pendingMutex.Unlock() - // Wake up the queue if it's asleep. - oq.wakeQueueIfNeeded() - select { - case oq.notify <- struct{}{}: - default: - } + oq.pendingMutex.Unlock() + + if !oq.backingOff.Load() { + oq.wakeQueueAndNotify() } } @@ -119,40 +102,47 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination) return } - // Create a database entry that associates the given PDU NID with - // this destination queue. We'll then be able to retrieve the PDU - // later. 
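// The rewritten sendEvent above deliberately skips the wake-up while the queue
// is backing off: waking it early would only burn a retry against a destination
// that is known to be failing. The retry is instead driven by
// handleBackoffNotifier, which getQueue registers with Statistics via
// AssignBackoffNotifier. A minimal sketch of that callback contract, assuming
// the statistics layer arms a timer for the backoff duration
// (scheduleBackoffNotify is an illustrative name, not part of this patch):

func scheduleBackoffNotify(backoff time.Duration, notify func()) *time.Timer {
	// time.AfterFunc invokes notify on its own goroutine once the backoff
	// has elapsed; in this patch the notify func is handleBackoffNotifier,
	// which re-wakes the destination queue if it is still backing off.
	return time.AfterFunc(backoff, notify)
}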
- if err := oq.db.AssociateEDUWithDestination( - oq.process.Context(), - oq.destination, // the destination server name - receipt, // NIDs from federationapi_queue_json table - event.Type, - nil, // this will use the default expireEDUTypes map - ); err != nil { - logrus.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination) - return + + // If there's room in memory to hold the event then add it to the + // list. + oq.pendingMutex.Lock() + if len(oq.pendingEDUs) < maxEDUsInMemory { + oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ + edu: event, + receipt: receipt, + }) + } else { + oq.overflowed.Store(true) } - // Check if the destination is blacklisted. If it isn't then wake - // up the queue. - if !oq.statistics.Blacklisted() { - // If there's room in memory to hold the event then add it to the - // list. - oq.pendingMutex.Lock() - if len(oq.pendingEDUs) < maxEDUsInMemory { - oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ - edu: event, - receipt: receipt, - }) - } else { - oq.overflowed.Store(true) - } - oq.pendingMutex.Unlock() - // Wake up the queue if it's asleep. - oq.wakeQueueIfNeeded() - select { - case oq.notify <- struct{}{}: - default: - } + oq.pendingMutex.Unlock() + + if !oq.backingOff.Load() { + oq.wakeQueueAndNotify() + } +} + +// handleBackoffNotifier is registered as the backoff notification +// callback with Statistics. It will wake up and notify the queue +// if the queue is currently backing off. +func (oq *destinationQueue) handleBackoffNotifier() { + // Only wake up the queue if it is backing off. + // Otherwise there is no pending work for the queue to handle, + // so waking the queue would be a waste of resources. + if oq.backingOff.Load() { + oq.wakeQueueAndNotify() + } +} + +// wakeQueueAndNotify ensures the destination queue is running and notifies it +// that there is pending work. +func (oq *destinationQueue) wakeQueueAndNotify() { + // Wake up the queue if it's asleep. + oq.wakeQueueIfNeeded() + + // Notify the queue that there are events ready to send. + select { + case oq.notify <- struct{}{}: + default: } } @@ -161,10 +151,11 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share // then we will interrupt the backoff, causing any federation // requests to retry. func (oq *destinationQueue) wakeQueueIfNeeded() { - // If we are backing off then interrupt the backoff. + // Clear the backingOff flag and update the backoff metrics if it was set. if oq.backingOff.CompareAndSwap(true, false) { - oq.interruptBackoff <- true + destinationQueueBackingOff.Dec() } + // If we aren't running then wake up the queue. if !oq.running.Load() { // Start the queue. @@ -196,38 +187,54 @@ func (oq *destinationQueue) getPendingFromDatabase() { gotEDUs[edu.receipt.String()] = struct{}{} } + overflowed := false if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 { // We have room in memory for some PDUs - request up to the full in-memory limit so a full page can be detected as overflow. 
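// wakeQueueAndNotify above depends on oq.notify having a buffer of exactly one
// (queue.go creates it with make(chan struct{}, 1)), so any number of
// concurrent wake-ups collapse into at most one pending token rather than
// blocking the senders. A self-contained sketch of the idiom:

func coalescedNotify() {
	notify := make(chan struct{}, 1)
	wake := func() {
		select {
		case notify <- struct{}{}: // first notification fills the buffer
		default: // one is already pending; extra wake-ups are redundant
		}
	}
	wake()
	wake()   // coalesced with the first token
	<-notify // the worker drains all pending work per token it receives
}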
- if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil { + if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, maxPDUsInMemory); err == nil { + if len(pdus) == maxPDUsInMemory { + overflowed = true + } for receipt, pdu := range pdus { if _, ok := gotPDUs[receipt.String()]; ok { continue } oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu}) retrieved = true + if len(oq.pendingPDUs) == maxPDUsInMemory { + break + } } } else { logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination) } } + if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 { // We have room in memory for some EDUs - request up to the full in-memory limit so a full page can be detected as overflow. - if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil { + if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, maxEDUsInMemory); err == nil { + if len(edus) == maxEDUsInMemory { + overflowed = true + } for receipt, edu := range edus { if _, ok := gotEDUs[receipt.String()]; ok { continue } oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu}) retrieved = true + if len(oq.pendingEDUs) == maxEDUsInMemory { + break + } } } else { logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination) } } + // If we've retrieved all of the events from the database with room to spare // in memory then we'll no longer consider this queue to be overflowed. - if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory { + if !overflowed { oq.overflowed.Store(false) } // If we've retrieved some events then notify the destination queue goroutine. if retrieved { @@ -238,6 +245,24 @@ } } +// checkNotificationsOnClose checks for any remaining notifications +// and starts a new backgroundSend goroutine if any exist. +func (oq *destinationQueue) checkNotificationsOnClose() { + // NOTE: If we are stopping the queue due to blacklist then it + // doesn't matter if we have been notified of new work since + // this queue instance will be deleted anyway. + if !oq.statistics.Blacklisted() { + select { + case <-oq.notify: + // We received a new notification in between the + // idle timeout firing and stopping the goroutine. + // Immediately restart the queue. + oq.wakeQueueAndNotify() + default: + } + } +} + // backgroundSend is the worker goroutine for sending events. func (oq *destinationQueue) backgroundSend() { // Check if a worker is already running, and if it isn't, then @@ -245,10 +270,17 @@ if !oq.running.CompareAndSwap(false, true) { return } + + // Register queue cleanup functions. + // NOTE: The ordering here is very intentional. + defer oq.checkNotificationsOnClose() + defer oq.running.Store(false) + destinationQueueRunning.Inc() defer destinationQueueRunning.Dec() - defer oq.queues.clearQueue(oq) - defer oq.running.Store(false) + + idleTimeout := time.NewTimer(queueIdleTimeout) + defer idleTimeout.Stop() // Mark the queue as overflowed, so we will consult the database // to see if there's anything new to send. @@ -261,59 +293,33 @@ oq.getPendingFromDatabase() } + // Reset the queue idle timeout. + if !idleTimeout.Stop() { + select { + case <-idleTimeout.C: + default: + } + } + idleTimeout.Reset(queueIdleTimeout) + // If we have nothing to do then wait either for incoming events, or // until we hit an idle timeout. 
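// The Stop/drain/Reset sequence above is the standard pattern for reusing a
// time.Timer: Stop reports false once the timer has already fired, in which
// case a stale expiry may still sit in the channel and would make the next
// select wake immediately unless it is drained first. The same steps as a
// helper:

func resetIdleTimer(t *time.Timer, d time.Duration) {
	if !t.Stop() {
		select {
		case <-t.C: // drop the stale expiry from a previous cycle
		default: // channel already drained by an earlier receive
		}
	}
	t.Reset(d)
}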
select { case <-oq.notify: // There's work to do, either because getPendingFromDatabase - // told us there is, or because a new event has come in via - // sendEvent/sendEDU. - case <-time.After(queueIdleTimeout): + // told us there is, a new event has come in via sendEvent/sendEDU, + // or we are backing off and it is time to retry. + case <-idleTimeout.C: // The worker is idle so stop the goroutine. It'll get // restarted automatically the next time we have an event to // send. return case <-oq.process.Context().Done(): // The parent process is shutting down, so stop. + oq.statistics.ClearBackoff() return } - // If we are backing off this server then wait for the - // backoff duration to complete first, or until explicitly - // told to retry. - until, blacklisted := oq.statistics.BackoffInfo() - if blacklisted { - // It's been suggested that we should give up because the backoff - // has exceeded a maximum allowable value. Clean up the in-memory - // buffers at this point. The PDU clean-up is already on a defer. - logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) - oq.pendingMutex.Lock() - for i := range oq.pendingPDUs { - oq.pendingPDUs[i] = nil - } - for i := range oq.pendingEDUs { - oq.pendingEDUs[i] = nil - } - oq.pendingPDUs = nil - oq.pendingEDUs = nil - oq.pendingMutex.Unlock() - return - } - if until != nil && until.After(time.Now()) { - // We haven't backed off yet, so wait for the suggested amount of - // time. - duration := time.Until(*until) - logrus.Debugf("Backing off %q for %s", oq.destination, duration) - oq.backingOff.Store(true) - destinationQueueBackingOff.Inc() - select { - case <-time.After(duration): - case <-oq.interruptBackoff: - } - destinationQueueBackingOff.Dec() - oq.backingOff.Store(false) - } - // Work out which PDUs/EDUs to include in the next transaction. oq.pendingMutex.RLock() pduCount := len(oq.pendingPDUs) @@ -328,99 +334,52 @@ func (oq *destinationQueue) backgroundSend() { toSendEDUs := oq.pendingEDUs[:eduCount] oq.pendingMutex.RUnlock() + // If we didn't get anything from the database and there are no + // pending EDUs then there's nothing to do - stop here. + if pduCount == 0 && eduCount == 0 { + continue + } + // If we have pending PDUs or EDUs then construct a transaction. // Try sending the next transaction and see what happens. - transaction, pc, ec, terr := oq.nextTransaction(toSendPDUs, toSendEDUs) + terr := oq.nextTransaction(toSendPDUs, toSendEDUs) if terr != nil { // We failed to send the transaction. Mark it as a failure. - oq.statistics.Failure() - - } else if transaction { - // If we successfully sent the transaction then clear out - // the pending events and EDUs, and wipe our transaction ID. - oq.statistics.Success() - oq.pendingMutex.Lock() - for i := range oq.pendingPDUs[:pc] { - oq.pendingPDUs[i] = nil + _, blacklisted := oq.statistics.Failure() + if !blacklisted { + // Register the backoff state and exit the goroutine. + // It'll get restarted automatically when the backoff + // completes. + oq.backingOff.Store(true) + destinationQueueBackingOff.Inc() + return + } else { + // Immediately trigger the blacklist logic. + oq.blacklistDestination() + return } - for i := range oq.pendingEDUs[:ec] { - oq.pendingEDUs[i] = nil - } - oq.pendingPDUs = oq.pendingPDUs[pc:] - oq.pendingEDUs = oq.pendingEDUs[ec:] - oq.pendingMutex.Unlock() + } else { + oq.handleTransactionSuccess(pduCount, eduCount) } } } // nextTransaction creates a new transaction from the pending event -// queue and sends it. 
Returns true if a transaction was sent or -// false otherwise. +// queue and sends it. +// Returns an error if the transaction wasn't sent. func (oq *destinationQueue) nextTransaction( pdus []*queuedPDU, edus []*queuedEDU, -) (bool, int, int, error) { - // If there's no projected transaction ID then generate one. If - // the transaction succeeds then we'll set it back to "" so that - // we generate a new one next time. If it fails, we'll preserve - // it so that we retry with the same transaction ID. - oq.transactionIDMutex.Lock() - if oq.transactionID == "" { - now := gomatrixserverlib.AsTimestamp(time.Now()) - oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) - } - oq.transactionIDMutex.Unlock() - +) error { // Create the transaction. - t := gomatrixserverlib.Transaction{ - PDUs: []json.RawMessage{}, - EDUs: []gomatrixserverlib.EDU{}, - } - t.Origin = oq.origin - t.Destination = oq.destination - t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) - t.TransactionID = oq.transactionID - - // If we didn't get anything from the database and there are no - // pending EDUs then there's nothing to do - stop here. - if len(pdus) == 0 && len(edus) == 0 { - return false, 0, 0, nil - } - - var pduReceipts []*shared.Receipt - var eduReceipts []*shared.Receipt - - // Go through PDUs that we retrieved from the database, if any, - // and add them into the transaction. - for _, pdu := range pdus { - if pdu == nil || pdu.pdu == nil { - continue - } - // Append the JSON of the event, since this is a json.RawMessage type in the - // gomatrixserverlib.Transaction struct - t.PDUs = append(t.PDUs, pdu.pdu.JSON()) - pduReceipts = append(pduReceipts, pdu.receipt) - } - - // Do the same for pending EDUS in the queue. - for _, edu := range edus { - if edu == nil || edu.edu == nil { - continue - } - t.EDUs = append(t.EDUs, *edu.edu) - eduReceipts = append(eduReceipts, edu.receipt) - } - + t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus) logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) // Try to send the transaction to the destination server. - // TODO: we should check for 500-ish fails vs 400-ish here, - // since we shouldn't queue things indefinitely in response - // to a 400-ish error ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5) defer cancel() _, err := oq.client.SendTransaction(ctx, t) - switch err.(type) { + switch errResponse := err.(type) { case nil: // Clean up the transaction in the database. if pduReceipts != nil { @@ -439,16 +398,129 @@ func (oq *destinationQueue) nextTransaction( oq.transactionIDMutex.Lock() oq.transactionID = "" oq.transactionIDMutex.Unlock() - return true, len(t.PDUs), len(t.EDUs), nil + return nil case gomatrix.HTTPError: // Report that we failed to send the transaction and we // will retry again, subject to backoff. - return false, 0, 0, err + + // TODO: we should check for 500-ish fails vs 400-ish here, + // since we shouldn't queue things indefinitely in response + // to a 400-ish error + code := errResponse.Code + logrus.Debug("Transaction failed with HTTP", code) + return err default: logrus.WithFields(logrus.Fields{ "destination": oq.destination, logrus.ErrorKey: err, }).Debugf("Failed to send transaction %q", t.TransactionID) - return false, 0, 0, err + return err + } +} + +// createTransaction generates a gomatrixserverlib.Transaction from the provided pdus and edus. 
+// It also returns the associated event receipts so they can be cleaned from the database in +// the case of a successful transaction. +func (oq *destinationQueue) createTransaction( + pdus []*queuedPDU, + edus []*queuedEDU, +) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) { + // If there's no projected transaction ID then generate one. If + // the transaction succeeds then we'll set it back to "" so that + // we generate a new one next time. If it fails, we'll preserve + // it so that we retry with the same transaction ID. + oq.transactionIDMutex.Lock() + if oq.transactionID == "" { + now := gomatrixserverlib.AsTimestamp(time.Now()) + oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) + } + oq.transactionIDMutex.Unlock() + + t := gomatrixserverlib.Transaction{ + PDUs: []json.RawMessage{}, + EDUs: []gomatrixserverlib.EDU{}, + } + t.Origin = oq.origin + t.Destination = oq.destination + t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) + t.TransactionID = oq.transactionID + + var pduReceipts []*shared.Receipt + var eduReceipts []*shared.Receipt + + // Go through PDUs that we retrieved from the database, if any, + // and add them into the transaction. + for _, pdu := range pdus { + // These should never be nil. + if pdu == nil || pdu.pdu == nil { + continue + } + // Append the JSON of the event, since this is a json.RawMessage type in the + // gomatrixserverlib.Transaction struct + t.PDUs = append(t.PDUs, pdu.pdu.JSON()) + pduReceipts = append(pduReceipts, pdu.receipt) + } + + // Do the same for pending EDUS in the queue. + for _, edu := range edus { + // These should never be nil. + if edu == nil || edu.edu == nil { + continue + } + t.EDUs = append(t.EDUs, *edu.edu) + eduReceipts = append(eduReceipts, edu.receipt) + } + + return t, pduReceipts, eduReceipts +} + +// blacklistDestination removes all pending PDUs and EDUs that have been cached +// and deletes this queue. +func (oq *destinationQueue) blacklistDestination() { + // It's been suggested that we should give up because the backoff + // has exceeded a maximum allowable value. Clean up the in-memory + // buffers at this point. The PDU clean-up is already on a defer. + logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) + + oq.pendingMutex.Lock() + for i := range oq.pendingPDUs { + oq.pendingPDUs[i] = nil + } + for i := range oq.pendingEDUs { + oq.pendingEDUs[i] = nil + } + oq.pendingPDUs = nil + oq.pendingEDUs = nil + oq.pendingMutex.Unlock() + + // Delete this queue as no more messages will be sent to this + // destination until it is no longer blacklisted. + oq.statistics.AssignBackoffNotifier(nil) + oq.queues.clearQueue(oq) +} + +// handleTransactionSuccess updates the cached event queues as well as the success and +// backoff information for this server. +func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) { + // If we successfully sent the transaction then clear out + // the pending events and EDUs, and wipe our transaction ID. 
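// createTransaction deliberately reuses the same TransactionID until a send
// succeeds (nextTransaction only clears it in its success branch). Receiving
// servers can treat (origin, transaction ID) as an idempotency key, so a retry
// after a timeout does not double-deliver the same PDUs. A sketch of that
// receiver-side assumption (receivedTxns is illustrative, not Dendrite code):

var receivedTxns = make(map[string]struct{}) // keyed by origin + "\x00" + txnID

func seenBefore(origin, txnID string) bool {
	key := origin + "\x00" + txnID
	if _, ok := receivedTxns[key]; ok {
		return true // a retry of a transaction that was already processed
	}
	receivedTxns[key] = struct{}{}
	return false
}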
+ oq.statistics.Success() + oq.pendingMutex.Lock() + defer oq.pendingMutex.Unlock() + + for i := range oq.pendingPDUs[:pduCount] { + oq.pendingPDUs[i] = nil + } + for i := range oq.pendingEDUs[:eduCount] { + oq.pendingEDUs[i] = nil + } + oq.pendingPDUs = oq.pendingPDUs[pduCount:] + oq.pendingEDUs = oq.pendingEDUs[eduCount:] + + if len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0 { + select { + case oq.notify <- struct{}{}: + default: + } } } diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 8245aa5bd..328334379 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -24,6 +24,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -162,23 +163,25 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d if !ok || oq == nil { destinationQueueTotal.Inc() oq = &destinationQueue{ - queues: oqs, - db: oqs.db, - process: oqs.process, - rsAPI: oqs.rsAPI, - origin: oqs.origin, - destination: destination, - client: oqs.client, - statistics: oqs.statistics.ForServer(destination), - notify: make(chan struct{}, 1), - interruptBackoff: make(chan bool), - signing: oqs.signing, + queues: oqs, + db: oqs.db, + process: oqs.process, + rsAPI: oqs.rsAPI, + origin: oqs.origin, + destination: destination, + client: oqs.client, + statistics: oqs.statistics.ForServer(destination), + notify: make(chan struct{}, 1), + signing: oqs.signing, } + oq.statistics.AssignBackoffNotifier(oq.handleBackoffNotifier) oqs.queues[destination] = oq } return oq } +// clearQueue removes the queue for the provided destination from the +// set of destination queues. func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) { oqs.queuesMutex.Lock() defer oqs.queuesMutex.Unlock() @@ -245,11 +248,25 @@ func (oqs *OutgoingQueues) SendEvent( } for destination := range destmap { - if queue := oqs.getQueue(destination); queue != nil { + if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() { queue.sendEvent(ev, nid) + } else { + delete(destmap, destination) } } + // Create a database entry that associates the given PDU NID with + // these destination queues. We'll then be able to retrieve the PDU + // later. + if err := oqs.db.AssociatePDUWithDestinations( + oqs.process.Context(), + destmap, + nid, // NIDs from federationapi_queue_json table + ); err != nil { + logrus.WithError(err).Errorf("failed to associate PDUs %q with destinations", nid) + return err + } + return nil } @@ -319,11 +336,27 @@ func (oqs *OutgoingQueues) SendEDU( } for destination := range destmap { - if queue := oqs.getQueue(destination); queue != nil { + if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() { queue.sendEDU(e, nid) + } else { + delete(destmap, destination) } } + // Create a database entry that associates the given EDU NID with + // these destination queues. We'll then be able to retrieve the EDU + // later. 
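// Both SendEvent and SendEDU now prune destmap in place before the single
// batched Associate*WithDestinations call below, so blacklisted servers never
// get a queue-JSON association row at all. The in-place pruning is safe
// because Go defines deleting map entries during a range over the same map:

func pruneBlacklisted(
	destmap map[gomatrixserverlib.ServerName]struct{},
	blacklisted func(gomatrixserverlib.ServerName) bool,
) {
	for destination := range destmap {
		if blacklisted(destination) {
			delete(destmap, destination) // permitted mid-range by the Go spec
		}
	}
}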
+ if err := oqs.db.AssociateEDUWithDestinations( + oqs.process.Context(), + destmap, // the destination server name + nid, // NIDs from federationapi_queue_json table + e.Type, + nil, // this will use the default expireEDUTypes map + ); err != nil { + logrus.WithError(err).Errorf("failed to associate EDU with destinations") + return err + } + return nil } @@ -332,7 +365,9 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { if oqs.disabled { return } + oqs.statistics.ForServer(srv).RemoveBlacklist() if queue := oqs.getQueue(srv); queue != nil { + queue.statistics.ClearBackoff() queue.wakeQueueIfNeeded() } } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go new file mode 100644 index 000000000..a1b280103 --- /dev/null +++ b/federationapi/queue/queue_test.go @@ -0,0 +1,1060 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "testing" + "time" + + "go.uber.org/atomic" + "gotest.tools/v3/poll" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + + "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/federationapi/storage" + "github.com/matrix-org/dendrite/federationapi/storage/shared" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" +) + +func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *process.ProcessContext, func()) { + if realDatabase { + // Real Database/s + b, baseClose := testrig.CreateBaseDendrite(t, dbType) + connStr, dbClose := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewDatabase(b, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, b.Caches, b.Cfg.Global.ServerName) + if err != nil { + t.Fatalf("NewDatabase returned %s", err) + } + return db, b.ProcessContext, func() { + dbClose() + baseClose() + } + } else { + // Fake Database + db := createDatabase() + b := struct { + ProcessContext *process.ProcessContext + }{ProcessContext: process.NewProcessContext()} + return db, b.ProcessContext, func() {} + } +} + +func createDatabase() storage.Database { + return &fakeDatabase{ + pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), + pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent), + pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU), + associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), + associatedEDUs: 
make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), + } +} + +type fakeDatabase struct { + storage.Database + dbMutex sync.Mutex + pendingPDUServers map[gomatrixserverlib.ServerName]struct{} + pendingEDUServers map[gomatrixserverlib.ServerName]struct{} + blacklistedServers map[gomatrixserverlib.ServerName]struct{} + pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent + pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU + associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} + associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} +} + +var nidMutex sync.Mutex +var nid = int64(0) + +func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal([]byte(js), &event); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + receipt := shared.NewReceipt(nid) + d.pendingPDUs[&receipt] = &event + return &receipt, nil + } + + var edu gomatrixserverlib.EDU + if err := json.Unmarshal([]byte(js), &edu); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + receipt := shared.NewReceipt(nid) + d.pendingEDUs[&receipt] = &edu + return &receipt, nil + } + + return nil, errors.New("Failed to determine type of json to store") +} + +func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + pduCount := 0 + pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent) + if receipts, ok := d.associatedPDUs[serverName]; ok { + for receipt := range receipts { + if event, ok := d.pendingPDUs[receipt]; ok { + pdus[receipt] = event + pduCount++ + if pduCount == limit { + break + } + } + } + } + return pdus, nil +} + +func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + eduCount := 0 + edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU) + if receipts, ok := d.associatedEDUs[serverName]; ok { + for receipt := range receipts { + if event, ok := d.pendingEDUs[receipt]; ok { + edus[receipt] = event + eduCount++ + if eduCount == limit { + break + } + } + } + } + return edus, nil +} + +func (d *fakeDatabase) AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingPDUs[receipt]; ok { + for destination := range destinations { + if _, ok := d.associatedPDUs[destination]; !ok { + d.associatedPDUs[destination] = make(map[*shared.Receipt]struct{}) + } + d.associatedPDUs[destination][receipt] = struct{}{} + } + + return nil + } else { + return errors.New("PDU doesn't exist") + } +} + +func (d *fakeDatabase) AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingEDUs[receipt]; ok { + for destination := range destinations { + if _, ok := d.associatedEDUs[destination]; !ok { + d.associatedEDUs[destination] = make(map[*shared.Receipt]struct{}) + } + 
d.associatedEDUs[destination][receipt] = struct{}{} + } + + return nil + } else { + return errors.New("EDU doesn't exist") + } +} + +func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if pdus, ok := d.associatedPDUs[serverName]; ok { + for _, receipt := range receipts { + delete(pdus, receipt) + } + } + + return nil +} + +func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if edus, ok := d.associatedEDUs[serverName]; ok { + for _, receipt := range receipts { + delete(edus, receipt) + } + } + + return nil +} + +func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if pdus, ok := d.associatedPDUs[serverName]; ok { + count = int64(len(pdus)) + } + return count, nil +} + +func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if edus, ok := d.associatedEDUs[serverName]; ok { + count = int64(len(edus)) + } + return count, nil +} + +func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingPDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingEDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers[serverName] = struct{}{} + return nil +} + +func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + delete(d.blacklistedServers, serverName) + return nil +} + +func (d *fakeDatabase) RemoveAllServersFromBlacklist() error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + isBlacklisted := false + if _, ok := d.blacklistedServers[serverName]; ok { + isBlacklisted = true + } + + return isBlacklisted, nil +} + +type stubFederationRoomServerAPI struct { + rsapi.FederationRoomserverAPI +} + +func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Context, req *rsapi.QueryServerBannedFromRoomRequest, res *rsapi.QueryServerBannedFromRoomResponse) error { + res.Banned = false + return nil +} + +type stubFederationClient struct { + api.FederationClient + shouldTxSucceed bool + txCount atomic.Uint32 +} + +func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) { + var result error + if !f.shouldTxSucceed { + result = fmt.Errorf("transaction 
failed") + } + + f.txCount.Add(1) + return gomatrixserverlib.RespSend{}, result +} + +func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent { + t.Helper() + content := `{"type":"m.room.message"}` + ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(content), false, gomatrixserverlib.RoomVersionV10) + if err != nil { + t.Fatalf("failed to create event: %v", err) + } + return ev.Headered(gomatrixserverlib.RoomVersionV10) +} + +func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU { + t.Helper() + return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping} +} + +func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { + db, processContext, close := mustCreateFederationDatabase(t, dbType, realDatabase) + + fc := &stubFederationClient{ + shouldTxSucceed: shouldTxSucceed, + txCount: *atomic.NewUint32(0), + } + rs := &stubFederationRoomServerAPI{} + stats := statistics.NewStatistics(db, failuresUntilBlacklist) + signingInfo := &SigningInfo{ + KeyID: "ed21019:auto", + PrivateKey: test.PrivateKeyA, + ServerName: "localhost", + } + queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo) + + return db, fc, queues, processContext, close +} + +func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == 1 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == 1 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUOnFailStoredInDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUOnFailStoredInDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + ev = mustCreatePDU(t) + err = queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + pollEnd := time.Now().Add(1 * time.Second) + immediateCheck := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Error(fmt.Errorf("The backoff was interrupted early")) + } + if time.Now().After(pollEnd) { + // Allow more than enough time for the backoff to be interrupted before + // reporting that it wasn't. + return poll.Success() + } + return poll.Continue("waiting for events to be removed from database. Currently present PDU: %d", len(data)) + } + poll.WaitOn(t, immediateCheck, poll.WithTimeout(2*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + ev = mustCreateEDU(t) + err = queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + pollEnd := time.Now().Add(1 * time.Second) + immediateCheck := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Error(fmt.Errorf("The backoff was interrupted early")) + } + if time.Now().After(pollEnd) { + // Allow more than enough time for the backoff to be interrupted before + // reporting that it wasn't. + return poll.Success() + } + return poll.Continue("waiting for events to be removed from database. 
Currently present EDU: %d", len(data)) + } + poll.WaitOn(t, immediateCheck, poll.WithTimeout(2*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + queues.statistics.ForServer(destination).Failure() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. 
Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + queues.statistics.ForServer(destination).Failure() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestRetryServerSendsPDUSuccessfully(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // NOTE : getQueue before sending event to ensure we grab the same queue reference + // before it is blacklisted and deleted. + dest := queues.getQueue(destination) + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + checkBlacklisted := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + if !dest.running.Load() { + return poll.Success() + } + return poll.Continue("waiting for queue to stop completely") + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + db.RemoveServerFromBlacklist(destination) + queues.RetryServer(destination) + checkRetry := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. 
Currently present PDU: %d", len(data)) + } + poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestRetryServerSendsEDUSuccessfully(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // NOTE : getQueue before sending event to ensure we grab the same queue reference + // before it is blacklisted and deleted. + dest := queues.getQueue(destination) + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + checkBlacklisted := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + if !dest.running.Load() { + return poll.Success() + } + return poll.Continue("waiting for queue to stop completely") + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + db.RemoveServerFromBlacklist(destination) + queues.RetryServer(destination) + checkRetry := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. 
Currently present EDU: %d", len(data)) + } + poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUBatches(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + + // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + // Populate database with > maxPDUsPerTransaction + pduMultiplier := uint32(3) + for i := 0; i < maxPDUsPerTransaction*int(pduMultiplier); i++ { + ev := mustCreatePDU(t) + headeredJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) + err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid) + assert.NoError(t, err, "failed to associate PDU with destinations") + } + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == pduMultiplier+1 { // +1 for the extra SendEvent() + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for all events to be removed from database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for the right amount of send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + // }) +} + +func TestSendEDUBatches(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + + // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + // Populate database with > maxEDUsPerTransaction + eduMultiplier := uint32(3) + for i := 0; i < maxEDUsPerTransaction*int(eduMultiplier); i++ { + ev := mustCreateEDU(t) + ephemeralJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) + err := db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, ev.Type, nil) + assert.NoError(t, err, "failed to associate EDU with destinations") + } + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == eduMultiplier+1 { // +1 for the extra SendEvent() + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for all events to be removed from database. 
Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for the right amount of send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + // }) +} + +func TestSendPDUAndEDUBatches(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + + // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + // Populate database with > maxEDUsPerTransaction + multiplier := uint32(3) + for i := 0; i < maxPDUsPerTransaction*int(multiplier)+1; i++ { + ev := mustCreatePDU(t) + headeredJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) + err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid) + assert.NoError(t, err, "failed to associate PDU with destinations") + } + + for i := 0; i < maxEDUsPerTransaction*int(multiplier); i++ { + ev := mustCreateEDU(t) + ephemeralJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) + err := db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, ev.Type, nil) + assert.NoError(t, err, "failed to associate EDU with destinations") + } + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == multiplier+1 { // +1 for the extra SendEvent() + pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrPDU) + eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrEDU) + if len(pduData) == 0 && len(eduData) == 0 { + return poll.Success() + } + return poll.Continue("waiting for all events to be removed from database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData)) + } + return poll.Continue("waiting for the right amount of send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + // }) +} + +func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + dest := queues.getQueue(destination) + queues.statistics.ForServer(destination).Failure() + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + ev := mustCreatePDU(t) + headeredJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) + err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid) + assert.NoError(t, err, "failed to associate PDU with destinations") + + pollEnd := time.Now().Add(3 * time.Second) + runningCheck := func(log poll.LogT) poll.Result { + if dest.running.Load() || fc.txCount.Load() > 0 { + return poll.Error(fmt.Errorf("The queue was started")) + } + if time.Now().After(pollEnd) { + // Allow more than enough time for the queue to be started in the case + // of backoff triggering it to start. + return poll.Success() + } + return poll.Continue("waiting to ensure queue doesn't start.") + } + poll.WaitOn(t, runningCheck, poll.WithTimeout(4*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { + // NOTE : Only one test case against real databases can be run at a time. + t.Parallel() + failuresUntilBlacklist := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true) + // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // NOTE : getQueue before sending event to ensure we grab the same queue reference + // before it is blacklisted and deleted. + dest := queues.getQueue(destination) + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + // NOTE : The server can be blacklisted before this, so manually inject the event + // into the database. 
+ edu := mustCreateEDU(t) + ephemeralJSON, _ := json.Marshal(edu) + nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) + err = db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, edu.Type, nil) + assert.NoError(t, err, "failed to associate EDU with destinations") + + checkBlacklisted := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrPDU) + eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrEDU) + if len(pduData) == 1 && len(eduData) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + if !dest.running.Load() { + return poll.Success() + } + return poll.Continue("waiting for queue to stop completely") + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for events to be added to database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + db.RemoveServerFromBlacklist(destination) + queues.RetryServer(destination) + checkRetry := func(log poll.LogT) poll.Result { + pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrPDU) + eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrEDU) + if len(pduData) == 0 && len(eduData) == 0 { + return poll.Success() + } + return poll.Continue("waiting for events to be removed from database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData)) + } + poll.WaitOn(t, checkRetry, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) + }) +} diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index db6d5c735..2ba99112c 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -2,6 +2,7 @@ package statistics import ( "math" + "math/rand" "sync" "time" @@ -20,12 +21,23 @@ type Statistics struct { servers map[gomatrixserverlib.ServerName]*ServerStatistics mutex sync.RWMutex + backoffTimers map[gomatrixserverlib.ServerName]*time.Timer + backoffMutex sync.RWMutex + // How many times should we tolerate consecutive failures before we // just blacklist the host altogether? The backoff is exponential, // so the max time here to attempt is 2**failures seconds. FailuresUntilBlacklist uint32 } +func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics { + return Statistics{ + DB: db, + FailuresUntilBlacklist: failuresUntilBlacklist, + backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), + } +} + // ForServer returns server statistics for the given server name. If it // does not exist, it will create empty statistics and return those. func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics { @@ -45,7 +57,6 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS server = &ServerStatistics{ statistics: s, serverName: serverName, - interrupt: make(chan struct{}), } s.servers[serverName] = server s.mutex.Unlock() @@ -64,29 +75,43 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS // many times we failed etc. 
It also manages the backoff time and black- // listing a remote host if it remains uncooperative. type ServerStatistics struct { - statistics *Statistics // - serverName gomatrixserverlib.ServerName // - blacklisted atomic.Bool // is the node blacklisted - backoffStarted atomic.Bool // is the backoff started - backoffUntil atomic.Value // time.Time until this backoff interval ends - backoffCount atomic.Uint32 // number of times BackoffDuration has been called - interrupt chan struct{} // interrupts the backoff goroutine - successCounter atomic.Uint32 // how many times have we succeeded? + statistics *Statistics // + serverName gomatrixserverlib.ServerName // + blacklisted atomic.Bool // is the node blacklisted + backoffStarted atomic.Bool // is the backoff started + backoffUntil atomic.Value // time.Time until this backoff interval ends + backoffCount atomic.Uint32 // number of times BackoffDuration has been called + successCounter atomic.Uint32 // how many times have we succeeded? + backoffNotifier func() // notifies destination queue when backoff completes + notifierMutex sync.Mutex } +const maxJitterMultiplier = 1.4 +const minJitterMultiplier = 0.8 + // duration returns how long the next backoff interval should be. func (s *ServerStatistics) duration(count uint32) time.Duration { - return time.Second * time.Duration(math.Exp2(float64(count))) + // Add some jitter to minimise the chance of having multiple backoffs + // ending at the same time. + jitter := rand.Float64()*(maxJitterMultiplier-minJitterMultiplier) + minJitterMultiplier + duration := time.Millisecond * time.Duration(math.Exp2(float64(count))*jitter*1000) + return duration } // cancel will interrupt the currently active backoff. func (s *ServerStatistics) cancel() { s.blacklisted.Store(false) s.backoffUntil.Store(time.Time{}) - select { - case s.interrupt <- struct{}{}: - default: - } + + s.ClearBackoff() +} + +// AssignBackoffNotifier configures the function to call when +// a backoff completes. +func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) { + s.notifierMutex.Lock() + defer s.notifierMutex.Unlock() + s.backoffNotifier = notifier } // Success updates the server statistics with a new successful @@ -95,8 +120,8 @@ func (s *ServerStatistics) cancel() { // we will unblacklist it. func (s *ServerStatistics) Success() { s.cancel() - s.successCounter.Inc() s.backoffCount.Store(0) + s.successCounter.Inc() if s.statistics.DB != nil { if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) @@ -105,13 +130,17 @@ func (s *ServerStatistics) Success() { } // Failure marks a failure and starts backing off if needed. -// The next call to BackoffIfRequired will do the right thing -// after this. It will return the time that the current failure +// It will return the time that the current failure // will result in backoff waiting until, and a bool signalling // whether we have blacklisted and therefore to give up. func (s *ServerStatistics) Failure() (time.Time, bool) { + // Return immediately if we have blacklisted this node. + if s.blacklisted.Load() { + return time.Time{}, true + } + // If we aren't already backing off, this call will start - // a new backoff period. Increase the failure counter and + // a new backoff period, increase the failure counter and // start a goroutine which will wait out the backoff and // unset the backoffStarted flag when done. 
if s.backoffStarted.CompareAndSwap(false, true) { @@ -122,40 +151,48 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) } } + s.ClearBackoff() return time.Time{}, true } - go func() { - until, ok := s.backoffUntil.Load().(time.Time) - if ok && !until.IsZero() { - select { - case <-time.After(time.Until(until)): - case <-s.interrupt: - } - s.backoffStarted.Store(false) - } - }() + // We're starting a new backoff, so work out what the next interval + // will be. + count := s.backoffCount.Load() + until := time.Now().Add(s.duration(count)) + s.backoffUntil.Store(until) + + s.statistics.backoffMutex.Lock() + defer s.statistics.backoffMutex.Unlock() + s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished) } - // Check if we have blacklisted this node. - if s.blacklisted.Load() { - return time.Now(), true - } + return s.backoffUntil.Load().(time.Time), false +} - // If we're already backing off and we haven't yet surpassed - // the deadline then return that. Repeated calls to Failure - // within a single backoff interval will have no side effects. - if until, ok := s.backoffUntil.Load().(time.Time); ok && !time.Now().After(until) { - return until, false +// ClearBackoff stops the backoff timer for this destination if it is running +// and removes the timer from the backoffTimers map. +func (s *ServerStatistics) ClearBackoff() { + // If the timer is still running then stop it so its memory is cleaned up sooner. + s.statistics.backoffMutex.Lock() + defer s.statistics.backoffMutex.Unlock() + if timer, ok := s.statistics.backoffTimers[s.serverName]; ok { + timer.Stop() } + delete(s.statistics.backoffTimers, s.serverName) - // We're either backing off and have passed the deadline, or - // we aren't backing off, so work out what the next interval - // will be. - count := s.backoffCount.Load() - until := time.Now().Add(s.duration(count)) - s.backoffUntil.Store(until) - return until, false + s.backoffStarted.Store(false) +} + +// backoffFinished will clear the previous backoff and notify the destination queue. +func (s *ServerStatistics) backoffFinished() { + s.ClearBackoff() + + // Notify the destinationQueue if one is currently running. + s.notifierMutex.Lock() + defer s.notifierMutex.Unlock() + if s.backoffNotifier != nil { + s.backoffNotifier() + } } // BackoffInfo returns information about the current or previous backoff. @@ -174,6 +211,12 @@ func (s *ServerStatistics) Blacklisted() bool { return s.blacklisted.Load() } +// RemoveBlacklist removes the blacklisted status from the server. +func (s *ServerStatistics) RemoveBlacklist() { + s.cancel() + s.backoffCount.Store(0) +} + // SuccessCount returns the number of successful requests. This is // usually useful in constructing transaction IDs. func (s *ServerStatistics) SuccessCount() uint32 { diff --git a/federationapi/statistics/statistics_test.go b/federationapi/statistics/statistics_test.go index 225350b6d..6aa997f44 100644 --- a/federationapi/statistics/statistics_test.go +++ b/federationapi/statistics/statistics_test.go @@ -7,9 +7,7 @@ import ( ) func TestBackoff(t *testing.T) { - stats := Statistics{ - FailuresUntilBlacklist: 7, - } + stats := NewStatistics(nil, 7) server := ServerStatistics{ statistics: &stats, serverName: "test.com", @@ -36,7 +34,7 @@ func TestBackoff(t *testing.T) { // Get the duration. 
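+ // With jitter applied, the backoff is no longer exactly 2^count seconds, so the checks below accept any duration inside the jitter range.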
_, blacklist := server.BackoffInfo() - duration := time.Until(until).Round(time.Second) + duration := time.Until(until) // Unset the backoff, or otherwise our next call will think that // there's a backoff in progress and return the same result. @@ -57,8 +55,17 @@ func TestBackoff(t *testing.T) { // Check if the duration is what we expect. t.Logf("Backoff %d is for %s", i, duration) - if wanted := time.Second * time.Duration(math.Exp2(float64(i))); !blacklist && duration != wanted { - t.Fatalf("Backoff %d should have been %s but was %s", i, wanted, duration) + roundingAllowance := 0.01 + minDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*minJitterMultiplier*1000-roundingAllowance) + maxDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*maxJitterMultiplier*1000+roundingAllowance) + inJitterRange := duration >= minDuration && duration <= maxDuration + if !blacklist && !inJitterRange { + t.Fatalf("Backoff %d should have been between %s and %s but was %s", i, minDuration, maxDuration, duration) } } } diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index b8109b432..09098cd1e 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -18,9 +18,10 @@ import ( "context" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/types" - "github.com/matrix-org/gomatrixserverlib" ) type Database interface { @@ -38,8 +39,8 @@ type Database interface { GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) - AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error - AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error + AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error + AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 9e40f311c..6afb313a8 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -52,6 +52,10 @@ type Receipt struct { nid int64 } +func NewReceipt(nid int64) Receipt { + return Receipt{nid: nid} +} + func (r *Receipt) String() string { return fmt.Sprintf("%d", r.nid) } diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index e0c740c11..c796d2f8f 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -38,9 +38,9 @@ var defaultExpireEDUTypes = 
map[string]time.Duration{ // AssociateEDUWithDestination creates an association that the // destination queues will use to determine which JSON blobs to send // to which servers. -func (d *Database) AssociateEDUWithDestination( +func (d *Database) AssociateEDUWithDestinations( ctx context.Context, - serverName gomatrixserverlib.ServerName, + destinations map[gomatrixserverlib.ServerName]struct{}, receipt *Receipt, eduType string, expireEDUTypes map[string]time.Duration, @@ -59,17 +59,18 @@ func (d *Database) AssociateEDUWithDestination( expiresAt = 0 } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.FederationQueueEDUs.InsertQueueEDU( - ctx, // context - txn, // SQL transaction - eduType, // EDU type for coalescing - serverName, // destination server name - receipt.nid, // NID from the federationapi_queue_json table - expiresAt, // The timestamp this EDU will expire - ); err != nil { - return fmt.Errorf("InsertQueueEDU: %w", err) + var err error + for destination := range destinations { + err = d.FederationQueueEDUs.InsertQueueEDU( + ctx, // context + txn, // SQL transaction + eduType, // EDU type for coalescing + destination, // destination server name + receipt.nid, // NID from the federationapi_queue_json table + expiresAt, // The timestamp this EDU will expire + ) + // Stop on the first failure so earlier errors aren't silently overwritten. + if err != nil { + break + } } - return nil + return err }) } diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index 5a12c388a..dc37d7507 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -27,23 +27,23 @@ import ( // AssociatePDUWithDestination creates an association that the // destination queues will use to determine which JSON blobs to send // to which servers. -func (d *Database) AssociatePDUWithDestination( +func (d *Database) AssociatePDUWithDestinations( ctx context.Context, - transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, + destinations map[gomatrixserverlib.ServerName]struct{}, receipt *Receipt, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.FederationQueuePDUs.InsertQueuePDU( - ctx, // context - txn, // SQL transaction - transactionID, // transaction ID - serverName, // destination server name - receipt.nid, // NID from the federationapi_queue_json table - ); err != nil { - return fmt.Errorf("InsertQueuePDU: %w", err) + var err error + for destination := range destinations { + err = d.FederationQueuePDUs.InsertQueuePDU( + ctx, // context + txn, // SQL transaction + "", // transaction ID + destination, // destination server name + receipt.nid, // NID from the federationapi_queue_json table + ) + // Stop on the first failure so earlier errors aren't silently overwritten. + if err != nil { + break + } } - return nil + return err }) } diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 3b0268e55..6272fd2b1 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -35,6 +35,7 @@ func TestExpireEDUs(t *testing.T) { } ctx := context.Background() + destinations := map[gomatrixserverlib.ServerName]struct{}{"localhost": {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateFederationDatabase(t, dbType) defer close() @@ -43,7 +44,7 @@ func TestExpireEDUs(t *testing.T) { receipt, err := db.StoreJSON(ctx, "{}") assert.NoError(t, err) - err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MReceipt, expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, 
gomatrixserverlib.MReceipt, expireEDUTypes) assert.NoError(t, err) } // add data without expiry @@ -51,7 +52,7 @@ func TestExpireEDUs(t *testing.T) { assert.NoError(t, err) // m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test - err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, "m.read_marker", expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, "m.read_marker", expireEDUTypes) assert.NoError(t, err) // Delete expired EDUs @@ -67,7 +68,7 @@ func TestExpireEDUs(t *testing.T) { receipt, err = db.StoreJSON(ctx, "{}") assert.NoError(t, err) - err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes) assert.NoError(t, err) err = db.DeleteExpiredEDUs(ctx) diff --git a/go.mod b/go.mod index 911d36c1c..7f9bb3897 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a + github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 @@ -50,6 +50,7 @@ require ( golang.org/x/term v0.0.0-20220919170432-7a66f970e087 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 + gotest.tools/v3 v3.4.0 nhooyr.io/websocket v1.8.7 ) @@ -127,7 +128,6 @@ require ( gopkg.in/macaroon.v2 v2.1.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - gotest.tools/v3 v3.4.0 // indirect ) go 1.18 diff --git a/go.sum b/go.sum index a141fc9b4..5cce7e0d8 100644 --- a/go.sum +++ b/go.sum @@ -387,8 +387,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a h1:bQKHk3AWlgm7XhzPhuU3Iw3pUptW5l1DR/1y0o7zCKQ= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a h1:6rJFN5NBuzZ7h5meYkLtXKa6VFZfDc8oVXHd4SDXr5o= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 h1:lzkSQvBv8TuqKJCPoVwOVvEnARTlua5rrNy/Qw2Vxeo= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 06fc4987c..49ef03054 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go 
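The keyserver hunk below threads a sync.Mutex through every writer of the shared QueryKeysResponse maps, since QueryKeys fans out to one goroutine per remote server and all of them write into the same response. A minimal sketch of the pattern, with hypothetical names standing in for Dendrite's actual types:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	// deviceKeys stands in for the shared response map that several
	// goroutines (one per remote server) want to populate.
	deviceKeys := make(map[string]map[string]string)
	var respMu sync.Mutex // guards deviceKeys

	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			userID := fmt.Sprintf("@user%d:example.org", i)
			// Writes to a shared map must be serialised, otherwise the
			// runtime may panic with "concurrent map writes".
			respMu.Lock()
			defer respMu.Unlock()
			if deviceKeys[userID] == nil {
				deviceKeys[userID] = make(map[string]string)
			}
			deviceKeys[userID]["DEVICE"] = "base64-key"
		}(i)
	}
	wg.Wait()
	fmt.Println("populated", len(deviceKeys), "users")
}
```

The same mutex has to cover every reader and writer of the maps, which is why the hunk passes respMu down into populateResponseWithDeviceKeysFromDatabase instead of locking only at the call sites.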
@@ -250,6 +250,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap // nolint:gocyclo func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { + var respMu sync.Mutex res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) @@ -329,7 +330,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } // attempt to satisfy key queries from the local database first as we should get device updates pushed to us - domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys) + domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys) if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 { // perform key queries for remote devices a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) @@ -407,7 +408,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } func (a *KeyInternalAPI) remoteKeysFromDatabase( - ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, + ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string, ) map[string]map[string][]string { fetchRemote := make(map[string]map[string][]string) for domain, userToDeviceMap := range domainToDeviceKeys { @@ -415,7 +416,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase( // we can't safely return keys from the db when all devices are requested as we don't // know if one has just been added. if len(deviceIDs) > 0 { - err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs) + err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs) if err == nil { continue } @@ -471,7 +472,9 @@ func (a *KeyInternalAPI) queryRemoteKeys( close(resultCh) }() - for result := range resultCh { + processResult := func(result *gomatrixserverlib.RespQueryKeys) { + respMu.Lock() + defer respMu.Unlock() for userID, nest := range result.DeviceKeys { res.DeviceKeys[userID] = make(map[string]json.RawMessage) for deviceID, deviceKey := range nest { @@ -494,6 +497,10 @@ func (a *KeyInternalAPI) queryRemoteKeys( // TODO: do we want to persist these somewhere now // that we have fetched them? } + + for result := range resultCh { + processResult(result) + } } func (a *KeyInternalAPI) queryRemoteKeysOnServer( @@ -541,9 +548,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( } // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this // user so the fact that we're populating all devices here isn't a problem so long as we have devices. - respMu.Lock() - err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil) - respMu.Unlock() + err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil) if err != nil { logrus.WithFields(logrus.Fields{ logrus.ErrorKey: err, @@ -567,25 +572,26 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( res.Failures[serverName] = map[string]interface{}{ "message": err.Error(), } + respMu.Unlock() // last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server // is down, better to return something than nothing at all. 
Clients can know about the failure by // inspecting the failures map though so they can know it's a cached response. for userID, dkeys := range devKeys { // drop the error as it's already a failure at this point - _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys) + _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys) } // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache + respMu.Lock() if len(res.DeviceKeys) > 0 { delete(res.Failures, serverName) } respMu.Unlock() - } func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( - ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string, + ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string, ) error { keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) // if we can't query the db or there are fewer keys than requested, fetch from remote. @@ -598,9 +604,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( if len(deviceIDs) == 0 && len(keys) == 0 { return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID) } + respMu.Lock() if res.DeviceKeys[userID] == nil { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + } + respMu.Unlock() for _, key := range keys { + if len(key.KeyJSON) == 0 { @@ -610,7 +618,9 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { DisplayName string `json:"device_display_name,omitempty"` }{key.DisplayName}) + respMu.Lock() res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON + respMu.Unlock() } return nil } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index cfbb05327..f767615c8 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -428,6 +428,13 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( return } + // Only notify clients about retired invite events if the user didn't accept the invite. + // The PDU stream will also receive an event about accepting the invitation, so there should + // be a "smooth" transition from invite -> join, and not invite -> leave -> join + if msg.Membership == gomatrixserverlib.Join { + return + } + // Notify any active sync requests that the invite has been retired. s.inviteStream.Advance(pduPos) s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID) diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index b562e6804..0ecbdf4d2 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -28,8 +28,9 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" "github.com/lib/pq" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) const outputRoomEventsSchema = ` @@ -133,7 +134,7 @@ const updateEventJSONSQL = "" + "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2" // In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). 
-const selectStateInRangeSQL = "" + +const selectStateInRangeFilteredSQL = "" + "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + @@ -146,6 +147,15 @@ const selectStateInRangeSQL = "" + " ORDER BY id ASC" + " LIMIT $9" +// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). +const selectStateInRangeSQL = "" + + "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" + + " FROM syncapi_output_room_events" + + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + + " AND room_id = ANY($3)" + + " ORDER BY id ASC" + + " LIMIT $4" + const deleteEventsForRoomSQL = "" + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" @@ -171,20 +181,21 @@ const selectContextAfterEventSQL = "" + const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type = ANY($2) ORDER BY id ASC LIMIT $3" type outputRoomEventsStatements struct { - insertEventStmt *sql.Stmt - selectEventsStmt *sql.Stmt - selectEventsWitFilterStmt *sql.Stmt - selectMaxEventIDStmt *sql.Stmt - selectRecentEventsStmt *sql.Stmt - selectRecentEventsForSyncStmt *sql.Stmt - selectEarlyEventsStmt *sql.Stmt - selectStateInRangeStmt *sql.Stmt - updateEventJSONStmt *sql.Stmt - deleteEventsForRoomStmt *sql.Stmt - selectContextEventStmt *sql.Stmt - selectContextBeforeEventStmt *sql.Stmt - selectContextAfterEventStmt *sql.Stmt - selectSearchStmt *sql.Stmt + insertEventStmt *sql.Stmt + selectEventsStmt *sql.Stmt + selectEventsWitFilterStmt *sql.Stmt + selectMaxEventIDStmt *sql.Stmt + selectRecentEventsStmt *sql.Stmt + selectRecentEventsForSyncStmt *sql.Stmt + selectEarlyEventsStmt *sql.Stmt + selectStateInRangeFilteredStmt *sql.Stmt + selectStateInRangeStmt *sql.Stmt + updateEventJSONStmt *sql.Stmt + deleteEventsForRoomStmt *sql.Stmt + selectContextEventStmt *sql.Stmt + selectContextBeforeEventStmt *sql.Stmt + selectContextAfterEventStmt *sql.Stmt + selectSearchStmt *sql.Stmt } func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { @@ -214,6 +225,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectRecentEventsStmt, selectRecentEventsSQL}, {&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL}, {&s.selectEarlyEventsStmt, selectEarlyEventsSQL}, + {&s.selectStateInRangeFilteredStmt, selectStateInRangeFilteredSQL}, {&s.selectStateInRangeStmt, selectStateInRangeSQL}, {&s.updateEventJSONStmt, updateEventJSONSQL}, {&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL}, @@ -240,17 +252,28 @@ func (s *outputRoomEventsStatements) SelectStateInRange( ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter, roomIDs []string, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { - stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) - senders, notSenders := getSendersStateFilterFilter(stateFilter) - rows, err := stmt.QueryContext( - ctx, r.Low(), r.High(), pq.StringArray(roomIDs), - pq.StringArray(senders), - pq.StringArray(notSenders), - pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), - pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), - stateFilter.ContainsURL, - stateFilter.Limit, - ) + var rows 
*sql.Rows + var err error + if stateFilter != nil { + stmt := sqlutil.TxStmt(txn, s.selectStateInRangeFilteredStmt) + senders, notSenders := getSendersStateFilterFilter(stateFilter) + rows, err = stmt.QueryContext( + ctx, r.Low(), r.High(), pq.StringArray(roomIDs), + pq.StringArray(senders), + pq.StringArray(notSenders), + pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), + pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), + stateFilter.ContainsURL, + stateFilter.Limit, + ) + } else { + stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) + rows, err = stmt.QueryContext( + ctx, r.Low(), r.High(), pq.StringArray(roomIDs), + r.High()-r.Low(), + ) + } + if err != nil { return nil, nil, err } diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index cb61c1c26..1f66ccc0e 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -5,10 +5,11 @@ import ( "database/sql" "fmt" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) type DatabaseTransaction struct { @@ -277,6 +278,7 @@ func (d *DatabaseTransaction) GetBackwardTopologyPos( // exclusive of oldPos, inclusive of newPos, for the rooms in which // the user has new membership events. // A list of joined room IDs is also returned in case the caller needs it. +// nolint:gocyclo func (d *DatabaseTransaction) GetStateDeltas( ctx context.Context, device *userapi.Device, r types.Range, userID string, @@ -311,7 +313,7 @@ func (d *DatabaseTransaction) GetStateDeltas( } // get all the state events ever (i.e. for all available rooms) between these two positions - stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) + stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, nil, allRoomIDs) if err != nil { if err == sql.ErrNoRows { return nil, nil, nil @@ -326,6 +328,22 @@ func (d *DatabaseTransaction) GetStateDeltas( return nil, nil, err } + // Now fetch the same state events again, this time filtered by the supplied state filter + stateNeededFiltered, eventMapFiltered, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + stateFiltered, err := d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + // find out which rooms this user is peeking, if any. // We do this before joins so any peeks get overwritten peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r) @@ -371,6 +389,7 @@ func (d *DatabaseTransaction) GetStateDeltas( // If our membership is now join but the previous membership wasn't // then this is a "join transition", so we'll insert this room. if prevMembership != membership { + newlyJoinedRooms[roomID] = true // Get the full room state, as we'll send that down for a newly // joined room instead of a delta. var s []types.StreamEvent + // Add the information for this room into the state so that // it will get added with all of the rest of the joined rooms. 
- state[roomID] = s - newlyJoinedRooms[roomID] = true + stateFiltered[roomID] = s } // We won't add joined rooms into the delta at this point as they @@ -395,7 +413,7 @@ func (d *DatabaseTransaction) GetStateDeltas( deltas = append(deltas, types.StateDelta{ Membership: membership, MembershipPos: ev.StreamPosition, - StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), + StateEvents: d.StreamEventsToEvents(device, stateFiltered[roomID]), RoomID: roomID, }) break @@ -407,7 +425,7 @@ func (d *DatabaseTransaction) GetStateDeltas( for _, joinedRoomID := range joinedRoomIDs { deltas = append(deltas, types.StateDelta{ Membership: gomatrixserverlib.Join, - StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), + StateEvents: d.StreamEventsToEvents(device, stateFiltered[joinedRoomID]), RoomID: joinedRoomID, NewlyJoined: newlyJoinedRooms[joinedRoomID], }) diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index d6a674b9c..77c692ff0 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -29,8 +29,9 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) const outputRoomEventsSchema = ` @@ -189,21 +190,36 @@ func (s *outputRoomEventsStatements) SelectStateInRange( for _, roomID := range roomIDs { inputParams = append(inputParams, roomID) } - stmt, params, err := prepareWithFilters( - s.db, txn, stmtSQL, inputParams, - stateFilter.Senders, stateFilter.NotSenders, - stateFilter.Types, stateFilter.NotTypes, - nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc, + var ( + stmt *sql.Stmt + params []any + err error ) + if stateFilter != nil { + stmt, params, err = prepareWithFilters( + s.db, txn, stmtSQL, inputParams, + stateFilter.Senders, stateFilter.NotSenders, + stateFilter.Types, stateFilter.NotTypes, + nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc, + ) + } else { + stmt, params, err = prepareWithFilters( + s.db, txn, stmtSQL, inputParams, + nil, nil, + nil, nil, + nil, nil, int(r.High()-r.Low()), FilterOrderAsc, + ) + } if err != nil { return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) } + defer internal.CloseAndLogIfError(ctx, stmt, "selectStateInRange: stmt.close() failed") rows, err := stmt.QueryContext(ctx, params...) 
if err != nil { return nil, nil, err } - defer rows.Close() // nolint: errcheck + defer internal.CloseAndLogIfError(ctx, rows, "selectStateInRange: rows.close() failed") // Fetch all the state change events for all rooms between the two positions then loop each event and: // - Keep a cache of the event by ID (99% of state change events are for the event itself) // - For each room ID, build up an array of event IDs which represents cumulative adds/removes @@ -269,6 +285,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID( ) (id int64, err error) { var nullableID sql.NullInt64 stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) + defer internal.CloseAndLogIfError(ctx, stmt, "SelectMaxEventID: stmt.close() failed") err = stmt.QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 @@ -323,6 +340,7 @@ func (s *outputRoomEventsStatements) InsertEvent( return 0, err } insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + defer internal.CloseAndLogIfError(ctx, insertStmt, "InsertEvent: stmt.close() failed") _, err = insertStmt.ExecContext( ctx, streamPos, @@ -367,6 +385,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( if err != nil { return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) } + defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed") rows, err := stmt.QueryContext(ctx, params...) if err != nil { @@ -415,6 +434,8 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( if err != nil { return nil, fmt.Errorf("s.prepareWithFilters: %w", err) } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectEarlyEvents: stmt.close() failed") + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, err @@ -456,6 +477,8 @@ func (s *outputRoomEventsStatements) SelectEvents( if err != nil { return nil, err } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectEvents: stmt.close() failed") + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, err @@ -558,6 +581,10 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent( filter.Types, filter.NotTypes, nil, filter.ContainsURL, filter.Limit, FilterOrderDesc, ) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectContextBeforeEvent: stmt.close() failed") rows, err := stmt.QueryContext(ctx, params...) if err != nil { @@ -596,6 +623,10 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent( filter.Types, filter.NotTypes, nil, filter.ContainsURL, filter.Limit, FilterOrderAsc, ) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectContextAfterEvent: stmt.close() failed") rows, err := stmt.QueryContext(ctx, params...) 
if err != nil { diff --git a/syncapi/storage/tables/memberships_test.go b/syncapi/storage/tables/memberships_test.go new file mode 100644 index 000000000..0cee7f5a5 --- /dev/null +++ b/syncapi/storage/tables/memberships_test.go @@ -0,0 +1,198 @@ +package tables_test + +import ( + "context" + "database/sql" + "reflect" + "sort" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" +) + +func newMembershipsTable(t *testing.T, dbType test.DBType) (tables.Memberships, *sql.DB, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + var tab tables.Memberships + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresMembershipsTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSqliteMembershipsTable(db) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + return tab, db, close +} + +func TestMembershipsTable(t *testing.T) { + + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + // Create users + var userEvents []*gomatrixserverlib.HeaderedEvent + users := []string{alice.ID} + for _, x := range room.CurrentState() { + if x.StateKeyEquals(alice.ID) { + if _, err := x.Membership(); err == nil { + userEvents = append(userEvents, x) + break + } + } + } + + if len(userEvents) == 0 { + t.Fatalf("didn't find creator membership event") + } + + for i := 0; i < 10; i++ { + u := test.NewUser(t) + users = append(users, u.ID) + + ev := room.CreateAndInsert(t, u, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(u.ID)) + userEvents = append(userEvents, ev) + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + table, _, close := newMembershipsTable(t, dbType) + defer close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + for _, ev := range userEvents { + if err := table.UpsertMembership(ctx, nil, ev, types.StreamPosition(ev.Depth()), 1); err != nil { + t.Fatalf("failed to upsert membership: %s", err) + } + } + + testUpsert(t, ctx, table, userEvents[0], alice, room) + testMembershipCount(t, ctx, table, room) + testHeroes(t, ctx, table, alice, room, users) + }) +} + +func testHeroes(t *testing.T, ctx context.Context, table tables.Memberships, user *test.User, room *test.Room, users []string) { + + // Re-slice and sort the expected users + users = users[1:] + sort.Strings(users) + type testCase struct { + name string + memberships []string + wantHeroes []string + } + + testCases := []testCase{ + {name: "no memberships queried", memberships: []string{}}, + {name: "joined memberships queried should be limited", memberships: []string{gomatrixserverlib.Join}, wantHeroes: users[:5]}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := table.SelectHeroes(ctx, nil, room.ID, user.ID, tc.memberships) + if err != nil { + t.Fatalf("unable to select heroes: %s", err) + } + if gotLen := len(got); gotLen 
diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go
index 7875ffa35..700f25c10 100644
--- a/syncapi/streams/stream_invite.go
+++ b/syncapi/streams/stream_invite.go
@@ -74,21 +74,26 @@ func (p *InviteStreamProvider) IncrementalSync(
 		return to
 	}
 	for roomID := range retiredInvites {
-		if _, ok := req.Response.Rooms.Join[roomID]; !ok {
-			lr := types.NewLeaveResponse()
-			h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...))
-			lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{
-				// fake event ID which muxes in the to position
-				EventID:        "$" + base64.RawURLEncoding.EncodeToString(h[:]),
-				OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()),
-				RoomID:         roomID,
-				Sender:         req.Device.UserID,
-				StateKey:       &req.Device.UserID,
-				Type:           "m.room.member",
-				Content:        gomatrixserverlib.RawJSON(`{"membership":"leave"}`),
-			})
-			req.Response.Rooms.Leave[roomID] = lr
-		}
+		if _, ok := req.Response.Rooms.Invite[roomID]; ok {
+			continue
+		}
+		if _, ok := req.Response.Rooms.Join[roomID]; ok {
+			continue
+		}
+		lr := types.NewLeaveResponse()
+		h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...))
+		lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{
+			// fake event ID which muxes in the to position
+			EventID:        "$" + base64.RawURLEncoding.EncodeToString(h[:]),
+			OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()),
+			RoomID:         roomID,
+			Sender:         req.Device.UserID,
+			StateKey:       &req.Device.UserID,
+			Type:           "m.room.member",
+			Content:        gomatrixserverlib.RawJSON(`{"membership":"leave"}`),
+		})
+		req.Response.Rooms.Leave[roomID] = lr
 	}

 	return maxID
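The synthetic leave event above derives its event ID deterministically from the room ID and the `to` stream position, so a client replaying the same incremental range gets the same fake event rather than a fresh one each time. A minimal standalone sketch of that derivation (the helper name fakeLeaveEventID is illustrative, not part of the patch):

package main

import (
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"strconv"
)

// fakeLeaveEventID reproduces the ID scheme used for the synthetic
// leave events: sha256(roomID || to), base64url-encoded, "$"-prefixed.
// The same (roomID, to) pair always yields the same event ID, so
// retried syncs do not produce duplicate leave events.
func fakeLeaveEventID(roomID string, to int64) string {
	h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(to, 10))...))
	return "$" + base64.RawURLEncoding.EncodeToString(h[:])
}

func main() {
	fmt.Println(fakeLeaveEventID("!abc:example.org", 42))
}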
diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go
index 613ac434f..9ec2b61cd 100644
--- a/syncapi/streams/stream_pdu.go
+++ b/syncapi/streams/stream_pdu.go
@@ -194,7 +194,7 @@ func (p *PDUStreamProvider) IncrementalSync(
 			}
 		}
 		var pos types.StreamPosition
-		if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil {
+		if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req); err != nil {
 			req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
 			if err == context.DeadlineExceeded || err == context.Canceled || err == sql.ErrTxDone {
 				return newPos
@@ -225,7 +225,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
 	delta types.StateDelta,
 	eventFilter *gomatrixserverlib.RoomEventFilter,
 	stateFilter *gomatrixserverlib.StateFilter,
-	res *types.Response,
+	req *types.SyncRequest,
 ) (types.StreamPosition, error) {
 	if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave {
 		// make sure we don't leak recent events after the leave event.
@@ -290,8 +290,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
 	hasMembershipChange := false
 	for _, recentEvent := range recentStreamEvents {
 		if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil {
+			if membership, _ := recentEvent.Membership(); membership == gomatrixserverlib.Join {
+				req.MembershipChanges[*recentEvent.StateKey()] = struct{}{}
+			}
 			hasMembershipChange = true
-			break
 		}
 	}
@@ -318,9 +320,9 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
 		jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
 		// If we are limited by the filter AND the history visibility filter
 		// didn't "remove" events, return that the response is limited.
-		jr.Timeline.Limited = limited && len(events) == len(recentEvents)
+		jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined
 		jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
-		res.Rooms.Join[delta.RoomID] = jr
+		req.Response.Rooms.Join[delta.RoomID] = jr

 	case gomatrixserverlib.Peek:
 		jr := types.NewJoinResponse()
@@ -329,7 +331,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
 		jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
 		jr.Timeline.Limited = limited
 		jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
-		res.Rooms.Peek[delta.RoomID] = jr
+		req.Response.Rooms.Peek[delta.RoomID] = jr

 	case gomatrixserverlib.Leave:
 		fallthrough // transitions to leave are the same as ban
@@ -342,7 +344,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
 		// didn't "remove" events, return that the response is limited.
 		lr.Timeline.Limited = limited && len(events) == len(recentEvents)
 		lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
-		res.Rooms.Leave[delta.RoomID] = lr
+		req.Response.Rooms.Leave[delta.RoomID] = lr
 	}

 	return latestPosition, nil
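Threading the whole *types.SyncRequest into addRoomDeltaToResponse lets the PDU stream do two new things: record every join it sees in req.MembershipChanges (read later by the presence stream), and force a newly joined room's timeline to be marked limited via delta.NewlyJoined so clients fetch the room's full state. This is also why the early `break` is gone: every member event in the timeline must now be inspected, not just the first. A reduced sketch of the recording loop, using simplified stand-in types rather than the real syncapi ones:

package main

import "fmt"

// Simplified stand-ins for the real syncapi event type; only the
// fields used by the sketch are modelled here.
type event struct {
	typ        string
	stateKey   *string
	membership string
}

// recordJoins mirrors the loop added to addRoomDeltaToResponse: any
// m.room.member event in the timeline marks a membership change, and
// joins additionally record the user so the presence stream can emit
// presence for them even when it would otherwise be skipped.
func recordJoins(events []event, membershipChanges map[string]struct{}) bool {
	hasMembershipChange := false
	for _, ev := range events {
		if ev.typ == "m.room.member" && ev.stateKey != nil {
			if ev.membership == "join" {
				membershipChanges[*ev.stateKey] = struct{}{}
			}
			hasMembershipChange = true
		}
	}
	return hasMembershipChange
}

func main() {
	alice := "@alice:example.org"
	changes := map[string]struct{}{}
	joined := recordJoins([]event{{typ: "m.room.member", stateKey: &alice, membership: "join"}}, changes)
	fmt.Println(joined, changes) // true map[@alice:example.org:{}]
}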
diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go
index 8b87af452..030b7c5d5 100644
--- a/syncapi/streams/stream_presence.go
+++ b/syncapi/streams/stream_presence.go
@@ -121,7 +121,8 @@ func (p *PresenceStreamProvider) IncrementalSync(
 		prevPresence := pres.(*types.PresenceInternal)
 		currentlyActive := prevPresence.CurrentlyActive()
 		skip := prevPresence.Equals(presence) && currentlyActive && req.Device.UserID != presence.UserID
-		if skip {
+		_, membershipChange := req.MembershipChanges[presence.UserID]
+		if skip && !membershipChange {
 			req.Log.Tracef("Skipping presence, no change (%s)", presence.UserID)
 			continue
 		}
diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go
index 268ed70c6..620dfdcdb 100644
--- a/syncapi/sync/request.go
+++ b/syncapi/sync/request.go
@@ -91,15 +91,16 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
 	})

 	return &types.SyncRequest{
-		Context:       req.Context(),           //
-		Log:           logger,                  //
-		Device:        &device,                 //
-		Response:      types.NewResponse(),     // Populated by all streams
-		Filter:        filter,                  //
-		Since:         since,                   //
-		Timeout:       timeout,                 //
-		Rooms:         make(map[string]string), // Populated by the PDU stream
-		WantFullState: wantFullState,           //
+		Context:           req.Context(),              //
+		Log:               logger,                     //
+		Device:            &device,                    //
+		Response:          types.NewResponse(),        // Populated by all streams
+		Filter:            filter,                     //
+		Since:             since,                      //
+		Timeout:           timeout,                    //
+		Rooms:             make(map[string]string),    // Populated by the PDU stream
+		WantFullState:     wantFullState,              //
+		MembershipChanges: make(map[string]struct{}),  // Populated by the PDU stream
 	}, nil
 }
diff --git a/syncapi/types/provider.go b/syncapi/types/provider.go
index 378cafe99..9a533002b 100644
--- a/syncapi/types/provider.go
+++ b/syncapi/types/provider.go
@@ -4,9 +4,10 @@ import (
 	"context"
 	"time"

-	userapi "github.com/matrix-org/dendrite/userapi/api"
 	"github.com/matrix-org/gomatrixserverlib"
 	"github.com/sirupsen/logrus"
+
+	userapi "github.com/matrix-org/dendrite/userapi/api"
 )

 type SyncRequest struct {
@@ -22,6 +23,8 @@ type SyncRequest struct {
 	// Updated by the PDU stream.
 	Rooms map[string]string
 	// Updated by the PDU stream.
+	MembershipChanges map[string]struct{}
+	// Updated by the PDU stream.
 	IgnoredUsers IgnoredUsers
 }
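On the presence side, a recorded membership change now defeats the "no change, skip" shortcut, so a syncing user receives the presence of members who have just joined one of their rooms even when that presence is otherwise unchanged. The gate, reduced to a pure function (names are illustrative):

package main

import "fmt"

// shouldSkip reproduces the gating added to the presence stream: an
// unchanged presence is normally skipped, unless the PDU stream
// recorded a membership change for that user in this sync response.
func shouldSkip(unchanged bool, userID string, membershipChanges map[string]struct{}) bool {
	_, membershipChange := membershipChanges[userID]
	return unchanged && !membershipChange
}

func main() {
	changes := map[string]struct{}{"@bob:example.org": {}}
	fmt.Println(shouldSkip(true, "@bob:example.org", changes))   // false: newly joined, send presence
	fmt.Println(shouldSkip(true, "@carol:example.org", changes)) // true: nothing changed, skip
}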
diff --git a/sytest-blacklist b/sytest-blacklist
index 634c07cf3..14edf398a 100644
--- a/sytest-blacklist
+++ b/sytest-blacklist
@@ -22,10 +22,6 @@ Forgotten room messages cannot be paginated

 Local device key changes get to remote servers with correct prev_id

-# Flakey
-
-Local device key changes appear in /keys/changes
-
 # we don't support groups
 Remove group category
@@ -39,12 +35,6 @@ Events in rooms with AS-hosted room aliases are sent to AS server
 Inviting an AS-hosted user asks the AS server
 Accesing an AS-hosted room alias asks the AS server

-# Flakey, need additional investigation
-
-Messages that notify from another user increment notification_count
-Messages that highlight from another user increment unread highlight count
-Notifications can be viewed with GET /notifications
-
 # More flakey

 Guest users can join guest_access rooms
diff --git a/sytest-whitelist b/sytest-whitelist
index 93d447d28..1387838f7 100644
--- a/sytest-whitelist
+++ b/sytest-whitelist
@@ -746,4 +746,10 @@ Existing members see new member's presence
 Inbound federation can return missing events for joined visibility
 outliers whose auth_events are in a different room are correctly rejected
 Messages that notify from another user increment notification_count
-Messages that highlight from another user increment unread highlight count
\ No newline at end of file
+Messages that highlight from another user increment unread highlight count
+Newly joined room has correct timeline in incremental sync
+When user joins a room the state is included in the next sync
+When user joins a room the state is included in a gapped sync
+Messages that notify from another user increment notification_count
+Messages that highlight from another user increment unread highlight count
+Notifications can be viewed with GET /notifications
\ No newline at end of file
diff --git a/userapi/api/api.go b/userapi/api/api.go
index 66ee9c7c8..eef29144a 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -96,7 +96,7 @@ type ClientUserAPI interface {
 	PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
 	PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
 	SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error
-	SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error
+	SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error
 	QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
 	InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
 	PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
@@ -579,7 +579,10 @@ type Notification struct {
 type PerformSetAvatarURLRequest struct {
 	Localpart, AvatarURL string
 }
-type PerformSetAvatarURLResponse struct{}
+type PerformSetAvatarURLResponse struct {
+	Profile *authtypes.Profile `json:"profile"`
+	Changed bool               `json:"changed"`
+}

 type QueryNumericLocalpartResponse struct {
 	ID int64
@@ -606,6 +609,11 @@ type PerformUpdateDisplayNameRequest struct {
 	Localpart, DisplayName string
 }

+type PerformUpdateDisplayNameResponse struct {
+	Profile *authtypes.Profile `json:"profile"`
+	Changed bool               `json:"changed"`
+}
+
 type QueryLocalpartForThreePIDRequest struct {
 	ThreePID, Medium string
 }
diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go
index 7e2f69615..90834f7e3 100644
--- a/userapi/api/api_trace.go
+++ b/userapi/api/api_trace.go
@@ -168,7 +168,7 @@ func (t *UserInternalAPITrace) QueryAccountAvailability(ctx context.Context, req
 	return err
 }

-func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error {
+func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error {
 	err := t.Impl.SetDisplayName(ctx, req, res)
 	util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res))
 	return err
diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go
index c220d35cb..79f1bf06f 100644
--- a/userapi/consumers/clientapi.go
+++ b/userapi/consumers/clientapi.go
@@ -81,7 +81,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
 	readPos := msg.Header.Get(jetstream.EventID)
 	evType := msg.Header.Get("type")

-	if readPos == "" || evType != "m.read" {
+	if readPos == "" || (evType != "m.read" && evType != "m.read.private") {
 		return true
 	}
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index 2f7795dfe..63044eedb 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -170,7 +170,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
 		return nil
 	}

-	if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
+	if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
 		return err
 	}

@@ -813,7 +813,10 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
 }

 func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
-	return a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
+	profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
+	res.Profile = profile
+	res.Changed = changed
+	return err
 }

 func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
@@ -847,8 +850,11 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
 	}
 }

-func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, _ *struct{}) error {
-	return a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
+func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
+	profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
+	res.Profile = profile
+	res.Changed = changed
+	return err
 }

 func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go
index a375d6caa..aa5d46d9f 100644
--- a/userapi/inthttp/client.go
+++ b/userapi/inthttp/client.go
@@ -388,7 +388,7 @@ func (h *httpUserInternalAPI) QueryAccountByPassword(
 func (h *httpUserInternalAPI) SetDisplayName(
 	ctx context.Context,
 	request *api.PerformUpdateDisplayNameRequest,
-	response *struct{},
+	response *api.PerformUpdateDisplayNameResponse,
 ) error {
 	return httputil.CallInternalRPCAPI(
 		"SetDisplayName", h.apiURL+PerformSetDisplayNamePath,
diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go
index f82fe7b69..28ef26559 100644
--- a/userapi/storage/interface.go
+++ b/userapi/storage/interface.go
@@ -31,8 +31,8 @@ import (
 type Profile interface {
 	GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
 	SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
-	SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
-	SetDisplayName(ctx context.Context, localpart string, displayName string) error
+	SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
+	SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error)
 }

 type Account interface {
diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go
index f686127be..2753b23d9 100644
--- a/userapi/storage/postgres/profile_table.go
+++ b/userapi/storage/postgres/profile_table.go
@@ -44,10 +44,18 @@ const selectProfileByLocalpartSQL = "" +
 	"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"

 const setAvatarURLSQL = "" +
-	"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2"
+	"UPDATE userapi_profiles AS new" +
+	" SET avatar_url = $1" +
+	" FROM userapi_profiles AS old" +
+	" WHERE new.localpart = $2" +
+	" RETURNING new.display_name, old.avatar_url <> new.avatar_url"

 const setDisplayNameSQL = "" +
-	"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2"
+	"UPDATE userapi_profiles AS new" +
+	" SET display_name = $1" +
+	" FROM userapi_profiles AS old" +
+	" WHERE new.localpart = $2" +
+	" RETURNING new.avatar_url, old.display_name <> new.display_name"

 const selectProfilesBySearchSQL = "" +
 	"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
@@ -100,16 +108,28 @@ func (s *profilesStatements) SelectProfileByLocalpart(
 func (s *profilesStatements) SetAvatarURL(
 	ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
-) (err error) {
-	_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
-	return
+) (*authtypes.Profile, bool, error) {
+	profile := &authtypes.Profile{
+		Localpart: localpart,
+		AvatarURL: avatarURL,
+	}
+	var changed bool
+	stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
+	err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed)
+	return profile, changed, err
 }

 func (s *profilesStatements) SetDisplayName(
 	ctx context.Context, txn *sql.Tx, localpart string, displayName string,
-) (err error) {
-	_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
-	return
+) (*authtypes.Profile, bool, error) {
+	profile := &authtypes.Profile{
+		Localpart:   localpart,
+		DisplayName: displayName,
+	}
+	var changed bool
+	stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
+	err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed)
+	return profile, changed, err
 }

 func (s *profilesStatements) SelectProfilesBySearch(
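The rewritten Postgres statements use an UPDATE ... FROM ... RETURNING self-join: `old` is scanned from the pre-update snapshot of the table while `new` is the updated target row, so a single round trip yields both the untouched column and an `old <> new` changed flag. A toy demonstration of the same pattern (placeholder connection string; this sketch additionally pins the `old` alias to the same localpart so the comparison can only ever see the pre-update version of that one row):

package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/lib/pq" // Postgres driver; the connection string below is a placeholder
)

func main() {
	db, err := sql.Open("postgres", "postgres://localhost/demo?sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	// Same shape as setAvatarURLSQL: the self-join exposes the row before
	// the update as "old" and after it as "new"; RETURNING hands back the
	// unchanged display name plus whether the avatar actually changed.
	const update = `UPDATE userapi_profiles AS new
		SET avatar_url = $1
		FROM userapi_profiles AS old
		WHERE new.localpart = $2 AND old.localpart = $2
		RETURNING new.display_name, old.avatar_url <> new.avatar_url`
	var displayName string
	var changed bool
	if err := db.QueryRow(update, "mxc://newAvatar", "alice").Scan(&displayName, &changed); err != nil {
		log.Fatal(err)
	}
	fmt.Println(displayName, changed)
}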
diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go
index 31e93e833..f8b8d02c9 100644
--- a/userapi/storage/shared/storage.go
+++ b/userapi/storage/shared/storage.go
@@ -95,20 +95,24 @@ func (d *Database) GetProfileByLocalpart(
 // localpart. Returns an error if something went wrong with the SQL query
 func (d *Database) SetAvatarURL(
 	ctx context.Context, localpart string, avatarURL string,
-) error {
-	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
-		return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
+) (profile *authtypes.Profile, changed bool, err error) {
+	err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+		profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
+		return err
 	})
+	return
 }

 // SetDisplayName updates the display name of the profile associated with the given
 // localpart. Returns an error if something went wrong with the SQL query
 func (d *Database) SetDisplayName(
 	ctx context.Context, localpart string, displayName string,
-) error {
-	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
-		return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
+) (profile *authtypes.Profile, changed bool, err error) {
+	err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+		profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
+		return err
 	})
+	return
 }

 // SetPassword sets the account password to the given hash.
diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go
index 267daf044..b6130a1e3 100644
--- a/userapi/storage/sqlite3/profile_table.go
+++ b/userapi/storage/sqlite3/profile_table.go
@@ -44,10 +44,12 @@ const selectProfileByLocalpartSQL = "" +
 	"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"

 const setAvatarURLSQL = "" +
-	"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2"
+	"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
+	" RETURNING display_name"

 const setDisplayNameSQL = "" +
-	"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2"
+	"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
+	" RETURNING avatar_url"

 const selectProfilesBySearchSQL = "" +
 	"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
@@ -102,18 +104,40 @@ func (s *profilesStatements) SelectProfileByLocalpart(
 func (s *profilesStatements) SetAvatarURL(
 	ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
-) (err error) {
+) (*authtypes.Profile, bool, error) {
+	profile := &authtypes.Profile{
+		Localpart: localpart,
+		AvatarURL: avatarURL,
+	}
+	old, err := s.SelectProfileByLocalpart(ctx, localpart)
+	if err != nil {
+		return old, false, err
+	}
+	if old.AvatarURL == avatarURL {
+		return old, false, nil
+	}
 	stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
-	_, err = stmt.ExecContext(ctx, avatarURL, localpart)
-	return
+	err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
+	return profile, true, err
 }

 func (s *profilesStatements) SetDisplayName(
 	ctx context.Context, txn *sql.Tx, localpart string, displayName string,
-) (err error) {
+) (*authtypes.Profile, bool, error) {
+	profile := &authtypes.Profile{
+		Localpart:   localpart,
+		DisplayName: displayName,
+	}
+	old, err := s.SelectProfileByLocalpart(ctx, localpart)
+	if err != nil {
+		return old, false, err
+	}
+	if old.DisplayName == displayName {
+		return old, false, nil
+	}
 	stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
-	_, err = stmt.ExecContext(ctx, displayName, localpart)
-	return
+	err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
+	return profile, true, err
 }

 func (s *profilesStatements) SelectProfilesBySearch(
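The SQLite implementation takes a different route because a plain RETURNING clause only surfaces post-update values: it reads the old profile first, short-circuits with changed=false when the value is already current, and only then runs the UPDATE. The enclosing Writer.Do call in userapi/storage/shared/storage.go runs this under the database's exclusive writer, so the read-compare-write sequence is not interleaved with other writers. A minimal sketch of the same flow (toy table name; assumes a SQLite build recent enough to support RETURNING):

package profiledb

import (
	"context"
	"database/sql"

	_ "github.com/mattn/go-sqlite3" // SQLite driver; the "profiles" schema is a toy stand-in
)

// setAvatarURL reads the current row first and only issues the UPDATE
// when the value actually differs, returning the remaining column and
// whether anything changed, in the same shape as the patched table code.
func setAvatarURL(ctx context.Context, txn *sql.Tx, localpart, avatarURL string) (displayName string, changed bool, err error) {
	var current string
	err = txn.QueryRowContext(ctx,
		"SELECT avatar_url FROM profiles WHERE localpart = $1", localpart,
	).Scan(&current)
	if err != nil {
		return "", false, err
	}
	if current == avatarURL {
		// Nothing to do; reporting changed=false lets callers skip
		// rebuilding membership events entirely.
		return "", false, nil
	}
	err = txn.QueryRowContext(ctx,
		"UPDATE profiles SET avatar_url = $1 WHERE localpart = $2 RETURNING display_name",
		avatarURL, localpart,
	).Scan(&displayName)
	return displayName, true, err
}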
diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go
index 8e5b32b6a..354f085fc 100644
--- a/userapi/storage/storage_test.go
+++ b/userapi/storage/storage_test.go
@@ -382,15 +382,23 @@ func Test_Profile(t *testing.T) {

 	// set avatar & displayname
 	wantProfile.DisplayName = "Alice"
-	wantProfile.AvatarURL = "mxc://aliceAvatar"
-	err = db.SetDisplayName(ctx, aliceLocalpart, "Alice")
-	assert.NoError(t, err, "unable to set displayname")
-	err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
-	assert.NoError(t, err, "unable to set avatar url")
-	// verify profile
-	gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart)
-	assert.NoError(t, err, "unable to get profile by localpart")
+	gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice")
 	assert.Equal(t, wantProfile, gotProfile)
+	assert.NoError(t, err, "unable to set displayname")
+	assert.True(t, changed)
+
+	wantProfile.AvatarURL = "mxc://aliceAvatar"
+	gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
+	assert.NoError(t, err, "unable to set avatar url")
+	assert.Equal(t, wantProfile, gotProfile)
+	assert.True(t, changed)
+
+	// Setting the same avatar again doesn't change anything
+	wantProfile.AvatarURL = "mxc://aliceAvatar"
+	gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
+	assert.NoError(t, err, "unable to set avatar url")
+	assert.Equal(t, wantProfile, gotProfile)
+	assert.False(t, changed)

 	// search profiles
 	searchRes, err := db.SearchProfiles(ctx, "Alice", 2)
diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go
index 9e69954e7..5e1dd0971 100644
--- a/userapi/storage/tables/interface.go
+++ b/userapi/storage/tables/interface.go
@@ -86,8 +86,8 @@ type OpenIDTable interface {
 type ProfileTable interface {
 	InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
 	SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
-	SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error)
-	SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error)
+	SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
+	SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error)
 	SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
 }
diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go
index 4417f4dc0..aaa93f45b 100644
--- a/userapi/userapi_test.go
+++ b/userapi/userapi_test.go
@@ -23,13 +23,14 @@ import (
 	"time"

 	"github.com/gorilla/mux"
+	"github.com/matrix-org/gomatrixserverlib"
+	"golang.org/x/crypto/bcrypt"
+
 	"github.com/matrix-org/dendrite/internal/httputil"
 	"github.com/matrix-org/dendrite/test"
 	"github.com/matrix-org/dendrite/test/testrig"
 	"github.com/matrix-org/dendrite/userapi"
 	"github.com/matrix-org/dendrite/userapi/inthttp"
-	"github.com/matrix-org/gomatrixserverlib"
-	"golang.org/x/crypto/bcrypt"

 	"github.com/matrix-org/dendrite/setup/config"
 	"github.com/matrix-org/dendrite/userapi/api"
@@ -83,10 +84,10 @@ func TestQueryProfile(t *testing.T) {
 	if err != nil {
 		t.Fatalf("failed to make account: %s", err)
 	}
-	if err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
+	if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
 		t.Fatalf("failed to set avatar url: %s", err)
 	}
-	if err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
+	if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
 		t.Fatalf("failed to set display name: %s", err)
 	}