diff --git a/CHANGES.md b/CHANGES.md index eea2c3c7c..1ed87824a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,23 @@ # Changelog +## Dendrite 0.10.4 (2022-10-21) + +### Features + +* Various tables belonging to the user API will be renamed so that they are namespaced with the `userapi_` prefix + * Note that, after upgrading to this version, you should not revert to an older version of Dendrite as the database changes **will not** be reverted automatically +* The backoff and retry behaviour in the federation API has been refactored and improved + +### Fixes + +* Private read receipt support is now advertised in the client `/versions` endpoint +* Private read receipts will now clear notification counts properly +* A bug where a false `leave` membership transition was inserted into the timeline after accepting an invite has been fixed +* Some panics caused by concurrent map writes in the key server have been fixed +* The sync API now calculates membership transitions from state deltas more accurately +* Transaction IDs are now scoped to endpoints, which should fix some bugs where transaction ID reuse could cause nonsensical cached responses from some endpoints +* The length of the `type`, `sender`, `state_key` and `room_id` fields in events are now verified by number of bytes rather than codepoints after a spec clarification, reverting a change made in Dendrite 0.9.6 + ## Dendrite 0.10.3 (2022-10-14) ### Features diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index af4702fc6..4e5531b22 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -19,6 +19,8 @@ import ( "net/http" "time" + "github.com/matrix-org/gomatrixserverlib" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" @@ -27,7 +29,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrix" "github.com/matrix-org/util" @@ -120,20 +121,6 @@ func SetAvatarURL( } } - res := &userapi.QueryProfileResponse{} - err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{ - UserID: userID, - }, res) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed") - return jsonerror.InternalServerError() - } - oldProfile := &authtypes.Profile{ - Localpart: localpart, - DisplayName: res.DisplayName, - AvatarURL: res.AvatarURL, - } - setRes := &userapi.PerformSetAvatarURLResponse{} if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{ Localpart: localpart, @@ -142,41 +129,17 @@ func SetAvatarURL( util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") return jsonerror.InternalServerError() } - - var roomsRes api.QueryRoomsForUserResponse - err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ - UserID: device.UserID, - WantMembership: "join", - }, &roomsRes) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") - return jsonerror.InternalServerError() - } - - newProfile := authtypes.Profile{ - Localpart: localpart, - DisplayName: oldProfile.DisplayName, - AvatarURL: r.AvatarURL, - } - - events, err := buildMembershipEvents( - req.Context(), roomsRes.RoomIDs, newProfile, userID, cfg, evTime, rsAPI, - ) - switch e := 
err.(type) { - case nil: - case gomatrixserverlib.BadJSONError: + // No need to build new membership events, since nothing changed + if !setRes.Changed { return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + Code: http.StatusOK, + JSON: struct{}{}, } - default: - util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed") - return jsonerror.InternalServerError() } - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() + response, err := updateProfile(req.Context(), rsAPI, device, setRes.Profile, userID, cfg, evTime) + if err != nil { + return response } return util.JSONResponse{ @@ -249,47 +212,51 @@ func SetDisplayName( } } - pRes := &userapi.QueryProfileResponse{} - err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{ - UserID: userID, - }, pRes) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed") - return jsonerror.InternalServerError() - } - oldProfile := &authtypes.Profile{ - Localpart: localpart, - DisplayName: pRes.DisplayName, - AvatarURL: pRes.AvatarURL, - } - + profileRes := &userapi.PerformUpdateDisplayNameResponse{} err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{ Localpart: localpart, DisplayName: r.DisplayName, - }, &struct{}{}) + }, profileRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") return jsonerror.InternalServerError() } + // No need to build new membership events, since nothing changed + if !profileRes.Changed { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } + } + response, err := updateProfile(req.Context(), rsAPI, device, profileRes.Profile, userID, cfg, evTime) + if err != nil { + return response + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +func updateProfile( + ctx context.Context, rsAPI api.ClientRoomserverAPI, device *userapi.Device, + profile *authtypes.Profile, + userID string, cfg *config.ClientAPI, evTime time.Time, +) (util.JSONResponse, error) { var res api.QueryRoomsForUserResponse - err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ + err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ UserID: device.UserID, WantMembership: "join", }, &res) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") - return jsonerror.InternalServerError() - } - - newProfile := authtypes.Profile{ - Localpart: localpart, - DisplayName: r.DisplayName, - AvatarURL: oldProfile.AvatarURL, + util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed") + return jsonerror.InternalServerError(), err } events, err := buildMembershipEvents( - req.Context(), res.RoomIDs, newProfile, userID, cfg, evTime, rsAPI, + ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, ) switch e := err.(type) { case nil: @@ -297,21 +264,17 @@ func SetDisplayName( return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(e.Error()), - } + }, e default: - util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed") - return jsonerror.InternalServerError() + util.GetLogger(ctx).WithError(err).Error("buildMembershipEvents failed") + return jsonerror.InternalServerError(), e } 
- if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() - } - - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, + if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { + util.GetLogger(ctx).WithError(err).Error("SendEvents failed") + return jsonerror.InternalServerError(), err } + return util.JSONResponse{}, nil } // getProfile gets the full profile of a user by querying the database or a diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 27f0ba5d0..a0f3b1152 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -19,6 +19,9 @@ import ( "net/http" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" @@ -26,8 +29,6 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) type redactionContent struct { @@ -51,7 +52,7 @@ func SendRedaction( if txnID != nil { // Try to fetch response from transactionsCache - if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok { return *res } } @@ -144,7 +145,7 @@ func SendRedaction( // Add response to transactionsCache if txnID != nil { - txnCache.AddTransaction(device.AccessToken, *txnID, &res) + txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res) } return res diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 2a0d8ea2e..d1b304fd9 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -72,6 +72,7 @@ func Setup( unstableFeatures := map[string]bool{ "org.matrix.e2e_cross_signing": true, + "org.matrix.msc2285.stable": true, } for _, msc := range cfg.MSCs.MSCs { unstableFeatures["org.matrix."+msc] = true @@ -179,7 +180,7 @@ func Setup( // server notifications if cfg.Matrix.ServerNotices.Enabled { logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") - serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, cfg) + serverNotificationSender, err := getSenderDevice(context.Background(), rsAPI, userAPI, cfg) if err != nil { logrus.WithError(err).Fatal("unable to get account for sending sending server notices") } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 85f1053f3..114e9088d 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -86,7 +86,7 @@ func SendEvent( if txnID != nil { // Try to fetch response from transactionsCache - if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok { return *res } } @@ -206,7 +206,7 @@ func SendEvent( } // Add response to transactionsCache if txnID != nil { - txnCache.AddTransaction(device.AccessToken, *txnID, &res) + txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res) } // Take a note of how long it took to 
generate the event vs submit diff --git a/clientapi/routing/sendtodevice.go b/clientapi/routing/sendtodevice.go index 4a5f08883..0c0227937 100644 --- a/clientapi/routing/sendtodevice.go +++ b/clientapi/routing/sendtodevice.go @@ -16,12 +16,13 @@ import ( "encoding/json" "net/http" + "github.com/matrix-org/util" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/internal/transactions" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/util" ) // SendToDevice handles PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId} @@ -33,7 +34,7 @@ func SendToDevice( eventType string, txnID *string, ) util.JSONResponse { if txnID != nil { - if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok { return *res } } @@ -63,7 +64,7 @@ func SendToDevice( } if txnID != nil { - txnCache.AddTransaction(device.AccessToken, *txnID, &res) + txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res) } return res diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index 9edeed2f7..a6a78061d 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -21,7 +21,6 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/tokens" @@ -29,6 +28,8 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/roomserver/version" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -73,7 +74,7 @@ func SendServerNotice( if txnID != nil { // Try to fetch response from transactionsCache - if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok { return *res } } @@ -251,7 +252,7 @@ func SendServerNotice( } // Add response to transactionsCache if txnID != nil { - txnCache.AddTransaction(device.AccessToken, *txnID, &res) + txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res) } // Take a note of how long it took to generate the event vs submit @@ -276,6 +277,7 @@ func (r sendServerNoticeRequest) valid() (ok bool) { // It returns an userapi.Device, which is used for building the event func getSenderDevice( ctx context.Context, + rsAPI api.ClientRoomserverAPI, userAPI userapi.ClientUserAPI, cfg *config.ClientAPI, ) (*userapi.Device, error) { @@ -290,16 +292,32 @@ func getSenderDevice( return nil, err } - // set the avatarurl for the user - res := &userapi.PerformSetAvatarURLResponse{} + // Set the avatarurl for the user + avatarRes := &userapi.PerformSetAvatarURLResponse{} if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{ Localpart: cfg.Matrix.ServerNotices.LocalPart, AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, - }, res); err != nil { + }, avatarRes); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed") return nil, err } + profile := avatarRes.Profile + + // Set the displayname for the user + displayNameRes := &userapi.PerformUpdateDisplayNameResponse{} + if err = userAPI.SetDisplayName(ctx, 
&userapi.PerformUpdateDisplayNameRequest{ + Localpart: cfg.Matrix.ServerNotices.LocalPart, + DisplayName: cfg.Matrix.ServerNotices.DisplayName, + }, displayNameRes); err != nil { + util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed") + return nil, err + } + + if displayNameRes.Changed { + profile.DisplayName = cfg.Matrix.ServerNotices.DisplayName + } + // Check if we got existing devices deviceRes := &userapi.QueryDevicesResponse{} err = userAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{ @@ -309,7 +327,15 @@ func getSenderDevice( return nil, err } + // We've got an existing account, return the first device of it if len(deviceRes.Devices) > 0 { + // If there were changes to the profile, create a new membership event + if displayNameRes.Changed || avatarRes.Changed { + _, err = updateProfile(ctx, rsAPI, &deviceRes.Devices[0], profile, accRes.Account.UserID, cfg, time.Now()) + if err != nil { + return nil, err + } + } return &deviceRes.Devices[0], nil } diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 52301415f..c8e239f29 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -179,7 +179,10 @@ func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, a body, _ = io.ReadAll(regResp.Body) return "", fmt.Errorf(gjson.GetBytes(body, "error").Str) } - r, _ := io.ReadAll(regResp.Body) + r, err := io.ReadAll(regResp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response body (HTTP %d): %w", regResp.StatusCode, err) + } return gjson.GetBytes(r, "access_token").Str, nil } diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index bc73df728..c7ba43711 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -231,9 +231,9 @@ GEM jekyll-seo-tag (~> 2.1) minitest (5.15.0) multipart-post (2.1.1) - nokogiri (1.13.6-arm64-darwin) + nokogiri (1.13.9-arm64-darwin) racc (~> 1.4) - nokogiri (1.13.6-x86_64-linux) + nokogiri (1.13.9-x86_64-linux) racc (~> 1.4) octokit (4.22.0) faraday (>= 0.9) diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 4a13c9d9b..f6dace702 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -116,17 +116,14 @@ func NewInternalAPI( _ = federationDB.RemoveAllServersFromBlacklist() } - stats := &statistics.Statistics{ - DB: federationDB, - FailuresUntilBlacklist: cfg.FederationMaxRetries, - } + stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1) js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, cfg.Matrix.DisableFederation, - cfg.Matrix.ServerName, federation, rsAPI, stats, + cfg.Matrix.ServerName, federation, rsAPI, &stats, &queue.SigningInfo{ KeyID: cfg.Matrix.KeyID, PrivateKey: cfg.Matrix.PrivateKey, @@ -183,5 +180,5 @@ func NewInternalAPI( } time.AfterFunc(time.Minute, cleanExpiredEDUs) - return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues, keyRing) + return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, &stats, caches, queues, keyRing) } diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 5cb8cae1f..1b7670e9a 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -21,21 +21,22 @@ import ( "sync" "time" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "go.uber.org/atomic" + 
fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" - "go.uber.org/atomic" ) const ( maxPDUsPerTransaction = 50 - maxEDUsPerTransaction = 50 + maxEDUsPerTransaction = 100 maxPDUsInMemory = 128 maxEDUsInMemory = 128 queueIdleTimeout = time.Second * 30 @@ -64,7 +65,6 @@ type destinationQueue struct { pendingPDUs []*queuedPDU // PDUs waiting to be sent pendingEDUs []*queuedEDU // EDUs waiting to be sent pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs - interruptBackoff chan bool // interrupts backoff } // Send event adds the event to the pending queue for the destination. @@ -75,39 +75,22 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) return } - // Create a database entry that associates the given PDU NID with - // this destination queue. We'll then be able to retrieve the PDU - // later. - if err := oq.db.AssociatePDUWithDestination( - oq.process.Context(), - "", // TODO: remove this, as we don't need to persist the transaction ID - oq.destination, // the destination server name - receipt, // NIDs from federationapi_queue_json table - ); err != nil { - logrus.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination) - return + + // If there's room in memory to hold the event then add it to the + // list. + oq.pendingMutex.Lock() + if len(oq.pendingPDUs) < maxPDUsInMemory { + oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ + pdu: event, + receipt: receipt, + }) + } else { + oq.overflowed.Store(true) } - // Check if the destination is blacklisted. If it isn't then wake - // up the queue. - if !oq.statistics.Blacklisted() { - // If there's room in memory to hold the event then add it to the - // list. - oq.pendingMutex.Lock() - if len(oq.pendingPDUs) < maxPDUsInMemory { - oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ - pdu: event, - receipt: receipt, - }) - } else { - oq.overflowed.Store(true) - } - oq.pendingMutex.Unlock() - // Wake up the queue if it's asleep. - oq.wakeQueueIfNeeded() - select { - case oq.notify <- struct{}{}: - default: - } + oq.pendingMutex.Unlock() + + if !oq.backingOff.Load() { + oq.wakeQueueAndNotify() } } @@ -119,40 +102,47 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination) return } - // Create a database entry that associates the given PDU NID with - // this destination queue. We'll then be able to retrieve the PDU - // later. - if err := oq.db.AssociateEDUWithDestination( - oq.process.Context(), - oq.destination, // the destination server name - receipt, // NIDs from federationapi_queue_json table - event.Type, - nil, // this will use the default expireEDUTypes map - ); err != nil { - logrus.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination) - return + + // If there's room in memory to hold the event then add it to the + // list. 
+ oq.pendingMutex.Lock() + if len(oq.pendingEDUs) < maxEDUsInMemory { + oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ + edu: event, + receipt: receipt, + }) + } else { + oq.overflowed.Store(true) } - // Check if the destination is blacklisted. If it isn't then wake - // up the queue. - if !oq.statistics.Blacklisted() { - // If there's room in memory to hold the event then add it to the - // list. - oq.pendingMutex.Lock() - if len(oq.pendingEDUs) < maxEDUsInMemory { - oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ - edu: event, - receipt: receipt, - }) - } else { - oq.overflowed.Store(true) - } - oq.pendingMutex.Unlock() - // Wake up the queue if it's asleep. - oq.wakeQueueIfNeeded() - select { - case oq.notify <- struct{}{}: - default: - } + oq.pendingMutex.Unlock() + + if !oq.backingOff.Load() { + oq.wakeQueueAndNotify() + } +} + +// handleBackoffNotifier is registered as the backoff notification +// callback with Statistics. It will wakeup and notify the queue +// if the queue is currently backing off. +func (oq *destinationQueue) handleBackoffNotifier() { + // Only wake up the queue if it is backing off. + // Otherwise there is no pending work for the queue to handle + // so waking the queue would be a waste of resources. + if oq.backingOff.Load() { + oq.wakeQueueAndNotify() + } +} + +// wakeQueueAndNotify ensures the destination queue is running and notifies it +// that there is pending work. +func (oq *destinationQueue) wakeQueueAndNotify() { + // Wake up the queue if it's asleep. + oq.wakeQueueIfNeeded() + + // Notify the queue that there are events ready to send. + select { + case oq.notify <- struct{}{}: + default: } } @@ -161,10 +151,11 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share // then we will interrupt the backoff, causing any federation // requests to retry. func (oq *destinationQueue) wakeQueueIfNeeded() { - // If we are backing off then interrupt the backoff. + // Clear the backingOff flag and update the backoff metrics if it was set. if oq.backingOff.CompareAndSwap(true, false) { - oq.interruptBackoff <- true + destinationQueueBackingOff.Dec() } + // If we aren't running then wake up the queue. if !oq.running.Load() { // Start the queue. @@ -196,38 +187,54 @@ func (oq *destinationQueue) getPendingFromDatabase() { gotEDUs[edu.receipt.String()] = struct{}{} } + overflowed := false if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 { // We have room in memory for some PDUs - let's request no more than that. - if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil { + if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, maxPDUsInMemory); err == nil { + if len(pdus) == maxPDUsInMemory { + overflowed = true + } for receipt, pdu := range pdus { if _, ok := gotPDUs[receipt.String()]; ok { continue } oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu}) retrieved = true + if len(oq.pendingPDUs) == maxPDUsInMemory { + break + } } } else { logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination) } } + if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 { // We have room in memory for some EDUs - let's request no more than that. 
- if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil { + if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, maxEDUsInMemory); err == nil { + if len(edus) == maxEDUsInMemory { + overflowed = true + } for receipt, edu := range edus { if _, ok := gotEDUs[receipt.String()]; ok { continue } oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu}) retrieved = true + if len(oq.pendingEDUs) == maxEDUsInMemory { + break + } } } else { logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination) } } + // If we've retrieved all of the events from the database with room to spare // in memory then we'll no longer consider this queue to be overflowed. - if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory { + if !overflowed { oq.overflowed.Store(false) + } else { } // If we've retrieved some events then notify the destination queue goroutine. if retrieved { @@ -238,6 +245,24 @@ func (oq *destinationQueue) getPendingFromDatabase() { } } +// checkNotificationsOnClose checks for any remaining notifications +// and starts a new backgroundSend goroutine if any exist. +func (oq *destinationQueue) checkNotificationsOnClose() { + // NOTE : If we are stopping the queue due to blacklist then it + // doesn't matter if we have been notified of new work since + // this queue instance will be deleted anyway. + if !oq.statistics.Blacklisted() { + select { + case <-oq.notify: + // We received a new notification in between the + // idle timeout firing and stopping the goroutine. + // Immediately restart the queue. + oq.wakeQueueAndNotify() + default: + } + } +} + // backgroundSend is the worker goroutine for sending events. func (oq *destinationQueue) backgroundSend() { // Check if a worker is already running, and if it isn't, then @@ -245,10 +270,17 @@ func (oq *destinationQueue) backgroundSend() { if !oq.running.CompareAndSwap(false, true) { return } + + // Register queue cleanup functions. + // NOTE : The ordering here is very intentional. + defer oq.checkNotificationsOnClose() + defer oq.running.Store(false) + destinationQueueRunning.Inc() defer destinationQueueRunning.Dec() - defer oq.queues.clearQueue(oq) - defer oq.running.Store(false) + + idleTimeout := time.NewTimer(queueIdleTimeout) + defer idleTimeout.Stop() // Mark the queue as overflowed, so we will consult the database // to see if there's anything new to send. @@ -261,59 +293,33 @@ func (oq *destinationQueue) backgroundSend() { oq.getPendingFromDatabase() } + // Reset the queue idle timeout. + if !idleTimeout.Stop() { + select { + case <-idleTimeout.C: + default: + } + } + idleTimeout.Reset(queueIdleTimeout) + // If we have nothing to do then wait either for incoming events, or // until we hit an idle timeout. select { case <-oq.notify: // There's work to do, either because getPendingFromDatabase - // told us there is, or because a new event has come in via - // sendEvent/sendEDU. - case <-time.After(queueIdleTimeout): + // told us there is, a new event has come in via sendEvent/sendEDU, + // or we are backing off and it is time to retry. + case <-idleTimeout.C: // The worker is idle so stop the goroutine. It'll get // restarted automatically the next time we have an event to // send. return case <-oq.process.Context().Done(): // The parent process is shutting down, so stop. + oq.statistics.ClearBackoff() return } - // If we are backing off this server then wait for the - // backoff duration to complete first, or until explicitly - // told to retry. 
- until, blacklisted := oq.statistics.BackoffInfo() - if blacklisted { - // It's been suggested that we should give up because the backoff - // has exceeded a maximum allowable value. Clean up the in-memory - // buffers at this point. The PDU clean-up is already on a defer. - logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) - oq.pendingMutex.Lock() - for i := range oq.pendingPDUs { - oq.pendingPDUs[i] = nil - } - for i := range oq.pendingEDUs { - oq.pendingEDUs[i] = nil - } - oq.pendingPDUs = nil - oq.pendingEDUs = nil - oq.pendingMutex.Unlock() - return - } - if until != nil && until.After(time.Now()) { - // We haven't backed off yet, so wait for the suggested amount of - // time. - duration := time.Until(*until) - logrus.Debugf("Backing off %q for %s", oq.destination, duration) - oq.backingOff.Store(true) - destinationQueueBackingOff.Inc() - select { - case <-time.After(duration): - case <-oq.interruptBackoff: - } - destinationQueueBackingOff.Dec() - oq.backingOff.Store(false) - } - // Work out which PDUs/EDUs to include in the next transaction. oq.pendingMutex.RLock() pduCount := len(oq.pendingPDUs) @@ -328,99 +334,52 @@ func (oq *destinationQueue) backgroundSend() { toSendEDUs := oq.pendingEDUs[:eduCount] oq.pendingMutex.RUnlock() + // If we didn't get anything from the database and there are no + // pending EDUs then there's nothing to do - stop here. + if pduCount == 0 && eduCount == 0 { + continue + } + // If we have pending PDUs or EDUs then construct a transaction. // Try sending the next transaction and see what happens. - transaction, pc, ec, terr := oq.nextTransaction(toSendPDUs, toSendEDUs) + terr := oq.nextTransaction(toSendPDUs, toSendEDUs) if terr != nil { // We failed to send the transaction. Mark it as a failure. - oq.statistics.Failure() - - } else if transaction { - // If we successfully sent the transaction then clear out - // the pending events and EDUs, and wipe our transaction ID. - oq.statistics.Success() - oq.pendingMutex.Lock() - for i := range oq.pendingPDUs[:pc] { - oq.pendingPDUs[i] = nil + _, blacklisted := oq.statistics.Failure() + if !blacklisted { + // Register the backoff state and exit the goroutine. + // It'll get restarted automatically when the backoff + // completes. + oq.backingOff.Store(true) + destinationQueueBackingOff.Inc() + return + } else { + // Immediately trigger the blacklist logic. + oq.blacklistDestination() + return } - for i := range oq.pendingEDUs[:ec] { - oq.pendingEDUs[i] = nil - } - oq.pendingPDUs = oq.pendingPDUs[pc:] - oq.pendingEDUs = oq.pendingEDUs[ec:] - oq.pendingMutex.Unlock() + } else { + oq.handleTransactionSuccess(pduCount, eduCount) } } } // nextTransaction creates a new transaction from the pending event -// queue and sends it. Returns true if a transaction was sent or -// false otherwise. +// queue and sends it. +// Returns an error if the transaction wasn't sent. func (oq *destinationQueue) nextTransaction( pdus []*queuedPDU, edus []*queuedEDU, -) (bool, int, int, error) { - // If there's no projected transaction ID then generate one. If - // the transaction succeeds then we'll set it back to "" so that - // we generate a new one next time. If it fails, we'll preserve - // it so that we retry with the same transaction ID. 
- oq.transactionIDMutex.Lock() - if oq.transactionID == "" { - now := gomatrixserverlib.AsTimestamp(time.Now()) - oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) - } - oq.transactionIDMutex.Unlock() - +) error { // Create the transaction. - t := gomatrixserverlib.Transaction{ - PDUs: []json.RawMessage{}, - EDUs: []gomatrixserverlib.EDU{}, - } - t.Origin = oq.origin - t.Destination = oq.destination - t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) - t.TransactionID = oq.transactionID - - // If we didn't get anything from the database and there are no - // pending EDUs then there's nothing to do - stop here. - if len(pdus) == 0 && len(edus) == 0 { - return false, 0, 0, nil - } - - var pduReceipts []*shared.Receipt - var eduReceipts []*shared.Receipt - - // Go through PDUs that we retrieved from the database, if any, - // and add them into the transaction. - for _, pdu := range pdus { - if pdu == nil || pdu.pdu == nil { - continue - } - // Append the JSON of the event, since this is a json.RawMessage type in the - // gomatrixserverlib.Transaction struct - t.PDUs = append(t.PDUs, pdu.pdu.JSON()) - pduReceipts = append(pduReceipts, pdu.receipt) - } - - // Do the same for pending EDUS in the queue. - for _, edu := range edus { - if edu == nil || edu.edu == nil { - continue - } - t.EDUs = append(t.EDUs, *edu.edu) - eduReceipts = append(eduReceipts, edu.receipt) - } - + t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus) logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) // Try to send the transaction to the destination server. - // TODO: we should check for 500-ish fails vs 400-ish here, - // since we shouldn't queue things indefinitely in response - // to a 400-ish error ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5) defer cancel() _, err := oq.client.SendTransaction(ctx, t) - switch err.(type) { + switch errResponse := err.(type) { case nil: // Clean up the transaction in the database. if pduReceipts != nil { @@ -439,16 +398,129 @@ func (oq *destinationQueue) nextTransaction( oq.transactionIDMutex.Lock() oq.transactionID = "" oq.transactionIDMutex.Unlock() - return true, len(t.PDUs), len(t.EDUs), nil + return nil case gomatrix.HTTPError: // Report that we failed to send the transaction and we // will retry again, subject to backoff. - return false, 0, 0, err + + // TODO: we should check for 500-ish fails vs 400-ish here, + // since we shouldn't queue things indefinitely in response + // to a 400-ish error + code := errResponse.Code + logrus.Debug("Transaction failed with HTTP", code) + return err default: logrus.WithFields(logrus.Fields{ "destination": oq.destination, logrus.ErrorKey: err, }).Debugf("Failed to send transaction %q", t.TransactionID) - return false, 0, 0, err + return err + } +} + +// createTransaction generates a gomatrixserverlib.Transaction from the provided pdus and edus. +// It also returns the associated event receipts so they can be cleaned from the database in +// the case of a successful transaction. +func (oq *destinationQueue) createTransaction( + pdus []*queuedPDU, + edus []*queuedEDU, +) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) { + // If there's no projected transaction ID then generate one. If + // the transaction succeeds then we'll set it back to "" so that + // we generate a new one next time. 
If it fails, we'll preserve + // it so that we retry with the same transaction ID. + oq.transactionIDMutex.Lock() + if oq.transactionID == "" { + now := gomatrixserverlib.AsTimestamp(time.Now()) + oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) + } + oq.transactionIDMutex.Unlock() + + t := gomatrixserverlib.Transaction{ + PDUs: []json.RawMessage{}, + EDUs: []gomatrixserverlib.EDU{}, + } + t.Origin = oq.origin + t.Destination = oq.destination + t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) + t.TransactionID = oq.transactionID + + var pduReceipts []*shared.Receipt + var eduReceipts []*shared.Receipt + + // Go through PDUs that we retrieved from the database, if any, + // and add them into the transaction. + for _, pdu := range pdus { + // These should never be nil. + if pdu == nil || pdu.pdu == nil { + continue + } + // Append the JSON of the event, since this is a json.RawMessage type in the + // gomatrixserverlib.Transaction struct + t.PDUs = append(t.PDUs, pdu.pdu.JSON()) + pduReceipts = append(pduReceipts, pdu.receipt) + } + + // Do the same for pending EDUS in the queue. + for _, edu := range edus { + // These should never be nil. + if edu == nil || edu.edu == nil { + continue + } + t.EDUs = append(t.EDUs, *edu.edu) + eduReceipts = append(eduReceipts, edu.receipt) + } + + return t, pduReceipts, eduReceipts +} + +// blacklistDestination removes all pending PDUs and EDUs that have been cached +// and deletes this queue. +func (oq *destinationQueue) blacklistDestination() { + // It's been suggested that we should give up because the backoff + // has exceeded a maximum allowable value. Clean up the in-memory + // buffers at this point. The PDU clean-up is already on a defer. + logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) + + oq.pendingMutex.Lock() + for i := range oq.pendingPDUs { + oq.pendingPDUs[i] = nil + } + for i := range oq.pendingEDUs { + oq.pendingEDUs[i] = nil + } + oq.pendingPDUs = nil + oq.pendingEDUs = nil + oq.pendingMutex.Unlock() + + // Delete this queue as no more messages will be sent to this + // destination until it is no longer blacklisted. + oq.statistics.AssignBackoffNotifier(nil) + oq.queues.clearQueue(oq) +} + +// handleTransactionSuccess updates the cached event queues as well as the success and +// backoff information for this server. +func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) { + // If we successfully sent the transaction then clear out + // the pending events and EDUs, and wipe our transaction ID. 
+ oq.statistics.Success() + oq.pendingMutex.Lock() + defer oq.pendingMutex.Unlock() + + for i := range oq.pendingPDUs[:pduCount] { + oq.pendingPDUs[i] = nil + } + for i := range oq.pendingEDUs[:eduCount] { + oq.pendingEDUs[i] = nil + } + oq.pendingPDUs = oq.pendingPDUs[pduCount:] + oq.pendingEDUs = oq.pendingEDUs[eduCount:] + + if len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0 { + select { + case oq.notify <- struct{}{}: + default: + } } } diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 8245aa5bd..328334379 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -24,6 +24,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -162,23 +163,25 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d if !ok || oq == nil { destinationQueueTotal.Inc() oq = &destinationQueue{ - queues: oqs, - db: oqs.db, - process: oqs.process, - rsAPI: oqs.rsAPI, - origin: oqs.origin, - destination: destination, - client: oqs.client, - statistics: oqs.statistics.ForServer(destination), - notify: make(chan struct{}, 1), - interruptBackoff: make(chan bool), - signing: oqs.signing, + queues: oqs, + db: oqs.db, + process: oqs.process, + rsAPI: oqs.rsAPI, + origin: oqs.origin, + destination: destination, + client: oqs.client, + statistics: oqs.statistics.ForServer(destination), + notify: make(chan struct{}, 1), + signing: oqs.signing, } + oq.statistics.AssignBackoffNotifier(oq.handleBackoffNotifier) oqs.queues[destination] = oq } return oq } +// clearQueue removes the queue for the provided destination from the +// set of destination queues. func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) { oqs.queuesMutex.Lock() defer oqs.queuesMutex.Unlock() @@ -245,11 +248,25 @@ func (oqs *OutgoingQueues) SendEvent( } for destination := range destmap { - if queue := oqs.getQueue(destination); queue != nil { + if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() { queue.sendEvent(ev, nid) + } else { + delete(destmap, destination) } } + // Create a database entry that associates the given PDU NID with + // this destinations queue. We'll then be able to retrieve the PDU + // later. + if err := oqs.db.AssociatePDUWithDestinations( + oqs.process.Context(), + destmap, + nid, // NIDs from federationapi_queue_json table + ); err != nil { + logrus.WithError(err).Errorf("failed to associate PDUs %q with destinations", nid) + return err + } + return nil } @@ -319,11 +336,27 @@ func (oqs *OutgoingQueues) SendEDU( } for destination := range destmap { - if queue := oqs.getQueue(destination); queue != nil { + if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() { queue.sendEDU(e, nid) + } else { + delete(destmap, destination) } } + // Create a database entry that associates the given PDU NID with + // this destination queue. We'll then be able to retrieve the PDU + // later. 
+ if err := oqs.db.AssociateEDUWithDestinations( + oqs.process.Context(), + destmap, // the destination server name + nid, // NIDs from federationapi_queue_json table + e.Type, + nil, // this will use the default expireEDUTypes map + ); err != nil { + logrus.WithError(err).Errorf("failed to associate EDU with destinations") + return err + } + return nil } @@ -332,7 +365,9 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { if oqs.disabled { return } + oqs.statistics.ForServer(srv).RemoveBlacklist() if queue := oqs.getQueue(srv); queue != nil { + queue.statistics.ClearBackoff() queue.wakeQueueIfNeeded() } } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go new file mode 100644 index 000000000..a1b280103 --- /dev/null +++ b/federationapi/queue/queue_test.go @@ -0,0 +1,1060 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "testing" + "time" + + "go.uber.org/atomic" + "gotest.tools/v3/poll" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + + "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/federationapi/storage" + "github.com/matrix-org/dendrite/federationapi/storage/shared" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" +) + +func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *process.ProcessContext, func()) { + if realDatabase { + // Real Database/s + b, baseClose := testrig.CreateBaseDendrite(t, dbType) + connStr, dbClose := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewDatabase(b, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, b.Caches, b.Cfg.Global.ServerName) + if err != nil { + t.Fatalf("NewDatabase returned %s", err) + } + return db, b.ProcessContext, func() { + dbClose() + baseClose() + } + } else { + // Fake Database + db := createDatabase() + b := struct { + ProcessContext *process.ProcessContext + }{ProcessContext: process.NewProcessContext()} + return db, b.ProcessContext, func() {} + } +} + +func createDatabase() storage.Database { + return &fakeDatabase{ + pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), + pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent), + pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU), + associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), + associatedEDUs: 
make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), + } +} + +type fakeDatabase struct { + storage.Database + dbMutex sync.Mutex + pendingPDUServers map[gomatrixserverlib.ServerName]struct{} + pendingEDUServers map[gomatrixserverlib.ServerName]struct{} + blacklistedServers map[gomatrixserverlib.ServerName]struct{} + pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent + pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU + associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} + associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} +} + +var nidMutex sync.Mutex +var nid = int64(0) + +func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal([]byte(js), &event); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + receipt := shared.NewReceipt(nid) + d.pendingPDUs[&receipt] = &event + return &receipt, nil + } + + var edu gomatrixserverlib.EDU + if err := json.Unmarshal([]byte(js), &edu); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + receipt := shared.NewReceipt(nid) + d.pendingEDUs[&receipt] = &edu + return &receipt, nil + } + + return nil, errors.New("Failed to determine type of json to store") +} + +func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + pduCount := 0 + pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent) + if receipts, ok := d.associatedPDUs[serverName]; ok { + for receipt := range receipts { + if event, ok := d.pendingPDUs[receipt]; ok { + pdus[receipt] = event + pduCount++ + if pduCount == limit { + break + } + } + } + } + return pdus, nil +} + +func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + eduCount := 0 + edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU) + if receipts, ok := d.associatedEDUs[serverName]; ok { + for receipt := range receipts { + if event, ok := d.pendingEDUs[receipt]; ok { + edus[receipt] = event + eduCount++ + if eduCount == limit { + break + } + } + } + } + return edus, nil +} + +func (d *fakeDatabase) AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingPDUs[receipt]; ok { + for destination := range destinations { + if _, ok := d.associatedPDUs[destination]; !ok { + d.associatedPDUs[destination] = make(map[*shared.Receipt]struct{}) + } + d.associatedPDUs[destination][receipt] = struct{}{} + } + + return nil + } else { + return errors.New("PDU doesn't exist") + } +} + +func (d *fakeDatabase) AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingEDUs[receipt]; ok { + for destination := range destinations { + if _, ok := d.associatedEDUs[destination]; !ok { + d.associatedEDUs[destination] = make(map[*shared.Receipt]struct{}) + } + 
d.associatedEDUs[destination][receipt] = struct{}{} + } + + return nil + } else { + return errors.New("EDU doesn't exist") + } +} + +func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if pdus, ok := d.associatedPDUs[serverName]; ok { + for _, receipt := range receipts { + delete(pdus, receipt) + } + } + + return nil +} + +func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if edus, ok := d.associatedEDUs[serverName]; ok { + for _, receipt := range receipts { + delete(edus, receipt) + } + } + + return nil +} + +func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if pdus, ok := d.associatedPDUs[serverName]; ok { + count = int64(len(pdus)) + } + return count, nil +} + +func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if edus, ok := d.associatedEDUs[serverName]; ok { + count = int64(len(edus)) + } + return count, nil +} + +func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingPDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingEDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers[serverName] = struct{}{} + return nil +} + +func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + delete(d.blacklistedServers, serverName) + return nil +} + +func (d *fakeDatabase) RemoveAllServersFromBlacklist() error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + isBlacklisted := false + if _, ok := d.blacklistedServers[serverName]; ok { + isBlacklisted = true + } + + return isBlacklisted, nil +} + +type stubFederationRoomServerAPI struct { + rsapi.FederationRoomserverAPI +} + +func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Context, req *rsapi.QueryServerBannedFromRoomRequest, res *rsapi.QueryServerBannedFromRoomResponse) error { + res.Banned = false + return nil +} + +type stubFederationClient struct { + api.FederationClient + shouldTxSucceed bool + txCount atomic.Uint32 +} + +func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) { + var result error + if !f.shouldTxSucceed { + result = fmt.Errorf("transaction 
failed") + } + + f.txCount.Add(1) + return gomatrixserverlib.RespSend{}, result +} + +func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent { + t.Helper() + content := `{"type":"m.room.message"}` + ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(content), false, gomatrixserverlib.RoomVersionV10) + if err != nil { + t.Fatalf("failed to create event: %v", err) + } + return ev.Headered(gomatrixserverlib.RoomVersionV10) +} + +func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU { + t.Helper() + return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping} +} + +func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { + db, processContext, close := mustCreateFederationDatabase(t, dbType, realDatabase) + + fc := &stubFederationClient{ + shouldTxSucceed: shouldTxSucceed, + txCount: *atomic.NewUint32(0), + } + rs := &stubFederationRoomServerAPI{} + stats := statistics.NewStatistics(db, failuresUntilBlacklist) + signingInfo := &SigningInfo{ + KeyID: "ed21019:auto", + PrivateKey: test.PrivateKeyA, + ServerName: "localhost", + } + queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo) + + return db, fc, queues, processContext, close +} + +func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == 1 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == 1 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUOnFailStoredInDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUOnFailStoredInDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + ev = mustCreatePDU(t) + err = queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + pollEnd := time.Now().Add(1 * time.Second) + immediateCheck := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Error(fmt.Errorf("The backoff was interrupted early")) + } + if time.Now().After(pollEnd) { + // Allow more than enough time for the backoff to be interrupted before + // reporting that it wasn't. + return poll.Success() + } + return poll.Continue("waiting for events to be removed from database. Currently present PDU: %d", len(data)) + } + poll.WaitOn(t, immediateCheck, poll.WithTimeout(2*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + ev = mustCreateEDU(t) + err = queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + pollEnd := time.Now().Add(1 * time.Second) + immediateCheck := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Error(fmt.Errorf("The backoff was interrupted early")) + } + if time.Now().After(pollEnd) { + // Allow more than enough time for the backoff to be interrupted before + // reporting that it wasn't. + return poll.Success() + } + return poll.Continue("waiting for events to be removed from database. 
Currently present EDU: %d", len(data)) + } + poll.WaitOn(t, immediateCheck, poll.WithTimeout(2*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + queues.statistics.ForServer(destination).Failure() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. 
Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + queues.statistics.ForServer(destination).Failure() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestRetryServerSendsPDUSuccessfully(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // NOTE : getQueue before sending event to ensure we grab the same queue reference + // before it is blacklisted and deleted. + dest := queues.getQueue(destination) + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + checkBlacklisted := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + if !dest.running.Load() { + return poll.Success() + } + return poll.Continue("waiting for queue to stop completely") + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + db.RemoveServerFromBlacklist(destination) + queues.RetryServer(destination) + checkRetry := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. 
Currently present PDU: %d", len(data)) + } + poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestRetryServerSendsEDUSuccessfully(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // NOTE : getQueue before sending event to ensure we grab the same queue reference + // before it is blacklisted and deleted. + dest := queues.getQueue(destination) + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + checkBlacklisted := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + if !dest.running.Load() { + return poll.Success() + } + return poll.Continue("waiting for queue to stop completely") + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + db.RemoveServerFromBlacklist(destination) + queues.RetryServer(destination) + checkRetry := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. 
Currently present EDU: %d", len(data)) + } + poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUBatches(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + + // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + // Populate database with > maxPDUsPerTransaction + pduMultiplier := uint32(3) + for i := 0; i < maxPDUsPerTransaction*int(pduMultiplier); i++ { + ev := mustCreatePDU(t) + headeredJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) + err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid) + assert.NoError(t, err, "failed to associate PDU with destinations") + } + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == pduMultiplier+1 { // +1 for the extra SendEvent() + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for all events to be removed from database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for the right amount of send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + // }) +} + +func TestSendEDUBatches(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + + // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + // Populate database with > maxEDUsPerTransaction + eduMultiplier := uint32(3) + for i := 0; i < maxEDUsPerTransaction*int(eduMultiplier); i++ { + ev := mustCreateEDU(t) + ephemeralJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) + err := db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, ev.Type, nil) + assert.NoError(t, err, "failed to associate EDU with destinations") + } + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == eduMultiplier+1 { // +1 for the extra SendEvent() + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for all events to be removed from database. 
Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for the right amount of send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + // }) +} + +func TestSendPDUAndEDUBatches(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + + // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + // Populate database with > maxEDUsPerTransaction + multiplier := uint32(3) + for i := 0; i < maxPDUsPerTransaction*int(multiplier)+1; i++ { + ev := mustCreatePDU(t) + headeredJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) + err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid) + assert.NoError(t, err, "failed to associate PDU with destinations") + } + + for i := 0; i < maxEDUsPerTransaction*int(multiplier); i++ { + ev := mustCreateEDU(t) + ephemeralJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) + err := db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, ev.Type, nil) + assert.NoError(t, err, "failed to associate EDU with destinations") + } + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == multiplier+1 { // +1 for the extra SendEvent() + pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrPDU) + eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrEDU) + if len(pduData) == 0 && len(eduData) == 0 { + return poll.Success() + } + return poll.Continue("waiting for all events to be removed from database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData)) + } + return poll.Continue("waiting for the right amount of send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + // }) +} + +func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + dest := queues.getQueue(destination) + queues.statistics.ForServer(destination).Failure() + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + ev := mustCreatePDU(t) + headeredJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) + err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid) + assert.NoError(t, err, "failed to associate PDU with destinations") + + pollEnd := time.Now().Add(3 * time.Second) + runningCheck := func(log poll.LogT) poll.Result { + if dest.running.Load() || fc.txCount.Load() > 0 { + return poll.Error(fmt.Errorf("The queue was started")) + } + if time.Now().After(pollEnd) { + // Allow more than enough time for the queue to be started in the case + // of backoff triggering it to start. + return poll.Success() + } + return poll.Continue("waiting to ensure queue doesn't start.") + } + poll.WaitOn(t, runningCheck, poll.WithTimeout(4*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { + // NOTE : Only one test case against real databases can be run at a time. + t.Parallel() + failuresUntilBlacklist := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true) + // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // NOTE : getQueue before sending event to ensure we grab the same queue reference + // before it is blacklisted and deleted. + dest := queues.getQueue(destination) + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + // NOTE : The server can be blacklisted before this, so manually inject the event + // into the database. 
+ edu := mustCreateEDU(t) + ephemeralJSON, _ := json.Marshal(edu) + nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) + err = db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, edu.Type, nil) + assert.NoError(t, err, "failed to associate EDU with destinations") + + checkBlacklisted := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrPDU) + eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrEDU) + if len(pduData) == 1 && len(eduData) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + if !dest.running.Load() { + return poll.Success() + } + return poll.Continue("waiting for queue to stop completely") + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for events to be added to database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + db.RemoveServerFromBlacklist(destination) + queues.RetryServer(destination) + checkRetry := func(log poll.LogT) poll.Result { + pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrPDU) + eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrEDU) + if len(pduData) == 0 && len(eduData) == 0 { + return poll.Success() + } + return poll.Continue("waiting for events to be removed from database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData)) + } + poll.WaitOn(t, checkRetry, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) + }) +} diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index db6d5c735..2ba99112c 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -2,6 +2,7 @@ package statistics import ( "math" + "math/rand" "sync" "time" @@ -20,12 +21,23 @@ type Statistics struct { servers map[gomatrixserverlib.ServerName]*ServerStatistics mutex sync.RWMutex + backoffTimers map[gomatrixserverlib.ServerName]*time.Timer + backoffMutex sync.RWMutex + // How many times should we tolerate consecutive failures before we // just blacklist the host altogether? The backoff is exponential, // so the max time here to attempt is 2**failures seconds. FailuresUntilBlacklist uint32 } +func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics { + return Statistics{ + DB: db, + FailuresUntilBlacklist: failuresUntilBlacklist, + backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), + } +} + // ForServer returns server statistics for the given server name. If it // does not exist, it will create empty statistics and return those. func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics { @@ -45,7 +57,6 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS server = &ServerStatistics{ statistics: s, serverName: serverName, - interrupt: make(chan struct{}), } s.servers[serverName] = server s.mutex.Unlock() @@ -64,29 +75,43 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS // many times we failed etc. 
It also manages the backoff time and black- // listing a remote host if it remains uncooperative. type ServerStatistics struct { - statistics *Statistics // - serverName gomatrixserverlib.ServerName // - blacklisted atomic.Bool // is the node blacklisted - backoffStarted atomic.Bool // is the backoff started - backoffUntil atomic.Value // time.Time until this backoff interval ends - backoffCount atomic.Uint32 // number of times BackoffDuration has been called - interrupt chan struct{} // interrupts the backoff goroutine - successCounter atomic.Uint32 // how many times have we succeeded? + statistics *Statistics // + serverName gomatrixserverlib.ServerName // + blacklisted atomic.Bool // is the node blacklisted + backoffStarted atomic.Bool // is the backoff started + backoffUntil atomic.Value // time.Time until this backoff interval ends + backoffCount atomic.Uint32 // number of times BackoffDuration has been called + successCounter atomic.Uint32 // how many times have we succeeded? + backoffNotifier func() // notifies destination queue when backoff completes + notifierMutex sync.Mutex } +const maxJitterMultiplier = 1.4 +const minJitterMultiplier = 0.8 + // duration returns how long the next backoff interval should be. func (s *ServerStatistics) duration(count uint32) time.Duration { - return time.Second * time.Duration(math.Exp2(float64(count))) + // Add some jitter to minimise the chance of having multiple backoffs + // ending at the same time. + jitter := rand.Float64()*(maxJitterMultiplier-minJitterMultiplier) + minJitterMultiplier + duration := time.Millisecond * time.Duration(math.Exp2(float64(count))*jitter*1000) + return duration } // cancel will interrupt the currently active backoff. func (s *ServerStatistics) cancel() { s.blacklisted.Store(false) s.backoffUntil.Store(time.Time{}) - select { - case s.interrupt <- struct{}{}: - default: - } + + s.ClearBackoff() +} + +// AssignBackoffNotifier configures the channel to send to when +// a backoff completes. +func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) { + s.notifierMutex.Lock() + defer s.notifierMutex.Unlock() + s.backoffNotifier = notifier } // Success updates the server statistics with a new successful @@ -95,8 +120,8 @@ func (s *ServerStatistics) cancel() { // we will unblacklist it. func (s *ServerStatistics) Success() { s.cancel() - s.successCounter.Inc() s.backoffCount.Store(0) + s.successCounter.Inc() if s.statistics.DB != nil { if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) @@ -105,13 +130,17 @@ func (s *ServerStatistics) Success() { } // Failure marks a failure and starts backing off if needed. -// The next call to BackoffIfRequired will do the right thing -// after this. It will return the time that the current failure +// It will return the time that the current failure // will result in backoff waiting until, and a bool signalling // whether we have blacklisted and therefore to give up. func (s *ServerStatistics) Failure() (time.Time, bool) { + // Return immediately if we have blacklisted this node. + if s.blacklisted.Load() { + return time.Time{}, true + } + // If we aren't already backing off, this call will start - // a new backoff period. Increase the failure counter and + // a new backoff period, increase the failure counter and // start a goroutine which will wait out the backoff and // unset the backoffStarted flag when done. 
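For reference, the jittered interval that duration() above produces is roughly 2**count seconds scaled by a random factor between minJitterMultiplier (0.8) and maxJitterMultiplier (1.4). A minimal standalone sketch of the same computation (illustration only, not part of the patch; Failure() continues directly below):

package main

import (
	"fmt"
	"math"
	"math/rand"
	"time"
)

const (
	minJitterMultiplier = 0.8
	maxJitterMultiplier = 1.4
)

// backoffDuration mirrors ServerStatistics.duration: 2^count seconds,
// scaled by a random jitter factor so that several destinations backing
// off at once do not all retry at exactly the same moment.
func backoffDuration(count uint32) time.Duration {
	jitter := rand.Float64()*(maxJitterMultiplier-minJitterMultiplier) + minJitterMultiplier
	return time.Millisecond * time.Duration(math.Exp2(float64(count))*jitter*1000)
}

func main() {
	for count := uint32(1); count <= 5; count++ {
		fmt.Printf("backoff %d: %s\n", count, backoffDuration(count))
	}
}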
if s.backoffStarted.CompareAndSwap(false, true) { @@ -122,40 +151,48 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) } } + s.ClearBackoff() return time.Time{}, true } - go func() { - until, ok := s.backoffUntil.Load().(time.Time) - if ok && !until.IsZero() { - select { - case <-time.After(time.Until(until)): - case <-s.interrupt: - } - s.backoffStarted.Store(false) - } - }() + // We're starting a new back off so work out what the next interval + // will be. + count := s.backoffCount.Load() + until := time.Now().Add(s.duration(count)) + s.backoffUntil.Store(until) + + s.statistics.backoffMutex.Lock() + defer s.statistics.backoffMutex.Unlock() + s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished) } - // Check if we have blacklisted this node. - if s.blacklisted.Load() { - return time.Now(), true - } + return s.backoffUntil.Load().(time.Time), false +} - // If we're already backing off and we haven't yet surpassed - // the deadline then return that. Repeated calls to Failure - // within a single backoff interval will have no side effects. - if until, ok := s.backoffUntil.Load().(time.Time); ok && !time.Now().After(until) { - return until, false +// ClearBackoff stops the backoff timer for this destination if it is running +// and removes the timer from the backoffTimers map. +func (s *ServerStatistics) ClearBackoff() { + // If the timer is still running then stop it so it's memory is cleaned up sooner. + s.statistics.backoffMutex.Lock() + defer s.statistics.backoffMutex.Unlock() + if timer, ok := s.statistics.backoffTimers[s.serverName]; ok { + timer.Stop() } + delete(s.statistics.backoffTimers, s.serverName) - // We're either backing off and have passed the deadline, or - // we aren't backing off, so work out what the next interval - // will be. - count := s.backoffCount.Load() - until := time.Now().Add(s.duration(count)) - s.backoffUntil.Store(until) - return until, false + s.backoffStarted.Store(false) +} + +// backoffFinished will clear the previous backoff and notify the destination queue. +func (s *ServerStatistics) backoffFinished() { + s.ClearBackoff() + + // Notify the destinationQueue if one is currently running. + s.notifierMutex.Lock() + defer s.notifierMutex.Unlock() + if s.backoffNotifier != nil { + s.backoffNotifier() + } } // BackoffInfo returns information about the current or previous backoff. @@ -174,6 +211,12 @@ func (s *ServerStatistics) Blacklisted() bool { return s.blacklisted.Load() } +// RemoveBlacklist removes the blacklisted status from the server. +func (s *ServerStatistics) RemoveBlacklist() { + s.cancel() + s.backoffCount.Store(0) +} + // SuccessCount returns the number of successful requests. This is // usually useful in constructing transaction IDs. func (s *ServerStatistics) SuccessCount() uint32 { diff --git a/federationapi/statistics/statistics_test.go b/federationapi/statistics/statistics_test.go index 225350b6d..6aa997f44 100644 --- a/federationapi/statistics/statistics_test.go +++ b/federationapi/statistics/statistics_test.go @@ -7,9 +7,7 @@ import ( ) func TestBackoff(t *testing.T) { - stats := Statistics{ - FailuresUntilBlacklist: 7, - } + stats := NewStatistics(nil, 7) server := ServerStatistics{ statistics: &stats, serverName: "test.com", @@ -36,7 +34,7 @@ func TestBackoff(t *testing.T) { // Get the duration. 
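Since backoffFinished above is what wakes a destination queue once its interval elapses, that interaction can be exercised much like the surrounding TestBackoff exercises the durations. A sketch of such a companion test, assuming it sits alongside TestBackoff in the statistics package (the TestBackoffNotifierSketch name and the 5-second timeout are illustrative only):

package statistics

import (
	"testing"
	"time"
)

// TestBackoffNotifierSketch registers a notifier, records one failure and
// waits for the jittered timer to call backoffFinished, which should in
// turn invoke the notifier.
func TestBackoffNotifierSketch(t *testing.T) {
	stats := NewStatistics(nil, 7) // nil DB: blacklist persistence is skipped
	server := ServerStatistics{
		statistics: &stats,
		serverName: "test.com",
	}

	notified := make(chan struct{}, 1)
	server.AssignBackoffNotifier(func() { notified <- struct{}{} })

	until, blacklisted := server.Failure() // schedules the backoff timer
	if blacklisted {
		t.Fatalf("a single failure should not blacklist the server")
	}

	select {
	case <-notified:
		t.Logf("backoff until %s elapsed and the notifier fired", until)
	case <-time.After(5 * time.Second):
		t.Fatalf("the backoff notifier never fired")
	}
}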
_, blacklist := server.BackoffInfo() - duration := time.Until(until).Round(time.Second) + duration := time.Until(until) // Unset the backoff, or otherwise our next call will think that // there's a backoff in progress and return the same result. @@ -57,8 +55,17 @@ func TestBackoff(t *testing.T) { // Check if the duration is what we expect. t.Logf("Backoff %d is for %s", i, duration) - if wanted := time.Second * time.Duration(math.Exp2(float64(i))); !blacklist && duration != wanted { - t.Fatalf("Backoff %d should have been %s but was %s", i, wanted, duration) + roundingAllowance := 0.01 + minDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*minJitterMultiplier*1000-roundingAllowance) + maxDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*maxJitterMultiplier*1000+roundingAllowance) + var inJitterRange bool + if duration >= minDuration && duration <= maxDuration { + inJitterRange = true + } else { + inJitterRange = false + } + if !blacklist && !inJitterRange { + t.Fatalf("Backoff %d should have been between %s and %s but was %s", i, minDuration, maxDuration, duration) } } } diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index b8109b432..09098cd1e 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -18,9 +18,10 @@ import ( "context" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/types" - "github.com/matrix-org/gomatrixserverlib" ) type Database interface { @@ -38,8 +39,8 @@ type Database interface { GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) - AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error - AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error + AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error + AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 9e40f311c..6afb313a8 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -52,6 +52,10 @@ type Receipt struct { nid int64 } +func NewReceipt(nid int64) Receipt { + return Receipt{nid: nid} +} + func (r *Receipt) String() string { return fmt.Sprintf("%d", r.nid) } diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index e0c740c11..c796d2f8f 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -38,9 +38,9 @@ var defaultExpireEDUTypes = 
map[string]time.Duration{ // AssociateEDUWithDestination creates an association that the // destination queues will use to determine which JSON blobs to send // to which servers. -func (d *Database) AssociateEDUWithDestination( +func (d *Database) AssociateEDUWithDestinations( ctx context.Context, - serverName gomatrixserverlib.ServerName, + destinations map[gomatrixserverlib.ServerName]struct{}, receipt *Receipt, eduType string, expireEDUTypes map[string]time.Duration, @@ -59,17 +59,18 @@ func (d *Database) AssociateEDUWithDestination( expiresAt = 0 } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.FederationQueueEDUs.InsertQueueEDU( - ctx, // context - txn, // SQL transaction - eduType, // EDU type for coalescing - serverName, // destination server name - receipt.nid, // NID from the federationapi_queue_json table - expiresAt, // The timestamp this EDU will expire - ); err != nil { - return fmt.Errorf("InsertQueueEDU: %w", err) + var err error + for destination := range destinations { + err = d.FederationQueueEDUs.InsertQueueEDU( + ctx, // context + txn, // SQL transaction + eduType, // EDU type for coalescing + destination, // destination server name + receipt.nid, // NID from the federationapi_queue_json table + expiresAt, // The timestamp this EDU will expire + ) } - return nil + return err }) } diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index 5a12c388a..dc37d7507 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -27,23 +27,23 @@ import ( // AssociatePDUWithDestination creates an association that the // destination queues will use to determine which JSON blobs to send // to which servers. -func (d *Database) AssociatePDUWithDestination( +func (d *Database) AssociatePDUWithDestinations( ctx context.Context, - transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, + destinations map[gomatrixserverlib.ServerName]struct{}, receipt *Receipt, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.FederationQueuePDUs.InsertQueuePDU( - ctx, // context - txn, // SQL transaction - transactionID, // transaction ID - serverName, // destination server name - receipt.nid, // NID from the federationapi_queue_json table - ); err != nil { - return fmt.Errorf("InsertQueuePDU: %w", err) + var err error + for destination := range destinations { + err = d.FederationQueuePDUs.InsertQueuePDU( + ctx, // context + txn, // SQL transaction + "", // transaction ID + destination, // destination server name + receipt.nid, // NID from the federationapi_queue_json table + ) } - return nil + return err }) } diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 3b0268e55..6272fd2b1 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -35,6 +35,7 @@ func TestExpireEDUs(t *testing.T) { } ctx := context.Background() + destinations := map[gomatrixserverlib.ServerName]struct{}{"localhost": {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateFederationDatabase(t, dbType) defer close() @@ -43,7 +44,7 @@ func TestExpireEDUs(t *testing.T) { receipt, err := db.StoreJSON(ctx, "{}") assert.NoError(t, err) - err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MReceipt, expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, 
gomatrixserverlib.MReceipt, expireEDUTypes) assert.NoError(t, err) } // add data without expiry @@ -51,7 +52,7 @@ func TestExpireEDUs(t *testing.T) { assert.NoError(t, err) // m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test - err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, "m.read_marker", expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, "m.read_marker", expireEDUTypes) assert.NoError(t, err) // Delete expired EDUs @@ -67,7 +68,7 @@ func TestExpireEDUs(t *testing.T) { receipt, err = db.StoreJSON(ctx, "{}") assert.NoError(t, err) - err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes) assert.NoError(t, err) err = db.DeleteExpiredEDUs(ctx) diff --git a/go.mod b/go.mod index d0c7fe32c..1a6ca9e58 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241 + github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 @@ -50,6 +50,7 @@ require ( golang.org/x/term v0.0.0-20220919170432-7a66f970e087 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 + gotest.tools/v3 v3.4.0 nhooyr.io/websocket v1.8.7 ) @@ -129,7 +130,6 @@ require ( gopkg.in/macaroon.v2 v2.1.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - gotest.tools/v3 v3.4.0 // indirect ) go 1.18 diff --git a/go.sum b/go.sum index c2ca463ce..25195e524 100644 --- a/go.sum +++ b/go.sum @@ -385,8 +385,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241 h1:e5o68MWeU7wjTvvNKmVo655oCYesoNRoPeBb1Xfz54g= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a h1:6rJFN5NBuzZ7h5meYkLtXKa6VFZfDc8oVXHd4SDXr5o= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 h1:lzkSQvBv8TuqKJCPoVwOVvEnARTlua5rrNy/Qw2Vxeo= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= diff --git a/internal/transactions/transactions.go b/internal/transactions/transactions.go index d2eb0f27f..7ff6f5044 100644 --- a/internal/transactions/transactions.go +++ 
b/internal/transactions/transactions.go @@ -13,6 +13,8 @@ package transactions import ( + "net/url" + "path/filepath" "sync" "time" @@ -29,6 +31,7 @@ type txnsMap map[CacheKey]*util.JSONResponse type CacheKey struct { AccessToken string TxnID string + Endpoint string } // Cache represents a temporary store for response entries. @@ -57,14 +60,14 @@ func NewWithCleanupPeriod(cleanupPeriod time.Duration) *Cache { return &t } -// FetchTransaction looks up an entry for the (accessToken, txnID) tuple in Cache. +// FetchTransaction looks up an entry for the (accessToken, txnID, req.URL) tuple in Cache. // Looks in both the txnMaps. // Returns (JSON response, true) if txnID is found, else the returned bool is false. -func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse, bool) { +func (t *Cache) FetchTransaction(accessToken, txnID string, u *url.URL) (*util.JSONResponse, bool) { t.RLock() defer t.RUnlock() for _, txns := range t.txnsMaps { - res, ok := txns[CacheKey{accessToken, txnID}] + res, ok := txns[CacheKey{accessToken, txnID, filepath.Dir(u.Path)}] if ok { return res, true } @@ -72,13 +75,12 @@ func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse, return nil, false } -// AddTransaction adds an entry for the (accessToken, txnID) tuple in Cache. +// AddTransaction adds an entry for the (accessToken, txnID, req.URL) tuple in Cache. // Adds to the front txnMap. -func (t *Cache) AddTransaction(accessToken, txnID string, res *util.JSONResponse) { +func (t *Cache) AddTransaction(accessToken, txnID string, u *url.URL, res *util.JSONResponse) { t.Lock() defer t.Unlock() - - t.txnsMaps[0][CacheKey{accessToken, txnID}] = res + t.txnsMaps[0][CacheKey{accessToken, txnID, filepath.Dir(u.Path)}] = res } // cacheCleanService is responsible for cleaning up entries after cleanupPeriod. diff --git a/internal/transactions/transactions_test.go b/internal/transactions/transactions_test.go index aa837f76c..c552550ac 100644 --- a/internal/transactions/transactions_test.go +++ b/internal/transactions/transactions_test.go @@ -14,6 +14,9 @@ package transactions import ( "net/http" + "net/url" + "path/filepath" + "reflect" "strconv" "testing" @@ -24,6 +27,16 @@ type fakeType struct { ID string `json:"ID"` } +func TestCompare(t *testing.T) { + u1, _ := url.Parse("/send/1?accessToken=123") + u2, _ := url.Parse("/send/1") + c1 := CacheKey{"1", "2", filepath.Dir(u1.Path)} + c2 := CacheKey{"1", "2", filepath.Dir(u2.Path)} + if !reflect.DeepEqual(c1, c2) { + t.Fatalf("Cache keys differ: %+v <> %+v", c1, c2) + } +} + var ( fakeAccessToken = "aRandomAccessToken" fakeAccessToken2 = "anotherRandomAccessToken" @@ -34,23 +47,28 @@ var ( fakeResponse2 = &util.JSONResponse{ Code: http.StatusOK, JSON: fakeType{ID: "1"}, } + fakeResponse3 = &util.JSONResponse{ + Code: http.StatusOK, JSON: fakeType{ID: "2"}, + } ) // TestCache creates a New Cache and tests AddTransaction & FetchTransaction func TestCache(t *testing.T) { fakeTxnCache := New() - fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse) + u, _ := url.Parse("") + fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, u, fakeResponse) // Add entries for noise. 
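The loop that follows adds the noise entries for TestCache; first, for reference, this is how the new Endpoint component of CacheKey is derived. filepath.Dir(u.Path) keeps everything up to, but not including, the final path segment (the transaction ID) and ignores the query string, so different endpoints that happen to reuse a transaction ID get separate cache entries. A standalone sketch using the same paths as the tests:

package main

import (
	"fmt"
	"net/url"
	"path/filepath"
)

func main() {
	for _, raw := range []string{
		"/send/1?accessToken=123", // the query string is not part of u.Path
		"/send/1",
		"/sendToDevice/1",
	} {
		u, err := url.Parse(raw)
		if err != nil {
			panic(err)
		}
		// This is the value stored in CacheKey.Endpoint.
		fmt.Printf("%-24s -> %s\n", raw, filepath.Dir(u.Path))
	}
}

Running it prints "/send" for both /send URLs and "/sendToDevice" for the third, so a response cached via /send is never returned for the same txnID on /sendToDevice.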
for i := 1; i <= 100; i++ { fakeTxnCache.AddTransaction( fakeAccessToken, fakeTxnID+strconv.Itoa(i), + u, &util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: strconv.Itoa(i)}}, ) } - testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID) + testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID, u) if !ok { t.Error("Failed to retrieve entry for txnID: ", fakeTxnID) } else if testResponse.JSON != fakeResponse.JSON { @@ -59,20 +77,30 @@ func TestCache(t *testing.T) { } // TestCacheScope ensures transactions with the same transaction ID are not shared -// across multiple access tokens. +// across multiple access tokens and endpoints. func TestCacheScope(t *testing.T) { cache := New() - cache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse) - cache.AddTransaction(fakeAccessToken2, fakeTxnID, fakeResponse2) + sendEndpoint, _ := url.Parse("/send/1?accessToken=test") + sendToDeviceEndpoint, _ := url.Parse("/sendToDevice/1") + cache.AddTransaction(fakeAccessToken, fakeTxnID, sendEndpoint, fakeResponse) + cache.AddTransaction(fakeAccessToken2, fakeTxnID, sendEndpoint, fakeResponse2) + cache.AddTransaction(fakeAccessToken2, fakeTxnID, sendToDeviceEndpoint, fakeResponse3) - if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID); !ok { + if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID, sendEndpoint); !ok { t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) } else if res.JSON != fakeResponse.JSON { t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse.JSON, res.JSON) } - if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID); !ok { + if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID, sendEndpoint); !ok { t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) } else if res.JSON != fakeResponse2.JSON { t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON) } + + // Ensure the txnID is not shared across endpoints + if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID, sendToDeviceEndpoint); !ok { + t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) + } else if res.JSON != fakeResponse3.JSON { + t.Errorf("Wrong cache entry for (%s, %s). 
Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON) + } } diff --git a/internal/version.go b/internal/version.go index c888748a8..5d739a45d 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 10 - VersionPatch = 3 + VersionPatch = 4 VersionTag = "" // example: "rc1" ) diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 06fc4987c..49ef03054 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -250,6 +250,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap // nolint:gocyclo func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { + var respMu sync.Mutex res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) @@ -329,7 +330,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } // attempt to satisfy key queries from the local database first as we should get device updates pushed to us - domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys) + domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys) if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 { // perform key queries for remote devices a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) @@ -407,7 +408,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } func (a *KeyInternalAPI) remoteKeysFromDatabase( - ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, + ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string, ) map[string]map[string][]string { fetchRemote := make(map[string]map[string][]string) for domain, userToDeviceMap := range domainToDeviceKeys { @@ -415,7 +416,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase( // we can't safely return keys from the db when all devices are requested as we don't // know if one has just been added. if len(deviceIDs) > 0 { - err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs) + err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs) if err == nil { continue } @@ -471,7 +472,9 @@ func (a *KeyInternalAPI) queryRemoteKeys( close(resultCh) }() - for result := range resultCh { + processResult := func(result *gomatrixserverlib.RespQueryKeys) { + respMu.Lock() + defer respMu.Unlock() for userID, nest := range result.DeviceKeys { res.DeviceKeys[userID] = make(map[string]json.RawMessage) for deviceID, deviceKey := range nest { @@ -494,6 +497,10 @@ func (a *KeyInternalAPI) queryRemoteKeys( // TODO: do we want to persist these somewhere now // that we have fetched them? } + + for result := range resultCh { + processResult(result) + } } func (a *KeyInternalAPI) queryRemoteKeysOnServer( @@ -541,9 +548,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( } // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this // user so the fact that we're populating all devices here isn't a problem so long as we have devices. 
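The hunk continues below with that refresh now passing respMu through. The pattern applied throughout this file is that every write into the shared QueryKeysResponse maps happens under respMu, because queryRemoteKeys fills them in from several goroutines at once. The hazard and the fix in miniature (a sketch with made-up names, not the keyserver API):

package main

import (
	"encoding/json"
	"fmt"
	"sync"
)

// mergeDeviceKeys stands in for the keyserver's pattern: many goroutines
// merge device keys for different servers into one shared response map,
// so every write must happen under the same mutex or the runtime panics
// with "concurrent map writes".
func mergeDeviceKeys(
	mu *sync.Mutex,
	deviceKeys map[string]map[string]json.RawMessage,
	userID, deviceID string,
	key json.RawMessage,
) {
	mu.Lock()
	defer mu.Unlock()
	if deviceKeys[userID] == nil {
		deviceKeys[userID] = make(map[string]json.RawMessage)
	}
	deviceKeys[userID][deviceID] = key
}

func main() {
	var mu sync.Mutex
	deviceKeys := map[string]map[string]json.RawMessage{}

	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			// Each goroutine plays the role of one remote key query result.
			mergeDeviceKeys(&mu, deviceKeys, "@alice:example.org", fmt.Sprintf("DEVICE%d", i), json.RawMessage(`{}`))
		}(i)
	}
	wg.Wait()
	fmt.Println("merged devices:", len(deviceKeys["@alice:example.org"]))
}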
- respMu.Lock() - err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil) - respMu.Unlock() + err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil) if err != nil { logrus.WithFields(logrus.Fields{ logrus.ErrorKey: err, @@ -567,25 +572,26 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( res.Failures[serverName] = map[string]interface{}{ "message": err.Error(), } + respMu.Unlock() // last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server // is down, better to return something than nothing at all. Clients can know about the failure by // inspecting the failures map though so they can know it's a cached response. for userID, dkeys := range devKeys { // drop the error as it's already a failure at this point - _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys) + _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys) } // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache + respMu.Lock() if len(res.DeviceKeys) > 0 { delete(res.Failures, serverName) } respMu.Unlock() - } func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( - ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string, + ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string, ) error { keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) // if we can't query the db or there are fewer keys than requested, fetch from remote. @@ -598,9 +604,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( if len(deviceIDs) == 0 && len(keys) == 0 { return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID) } + respMu.Lock() if res.DeviceKeys[userID] == nil { res.DeviceKeys[userID] = make(map[string]json.RawMessage) } + respMu.Unlock() for _, key := range keys { if len(key.KeyJSON) == 0 { @@ -610,7 +618,9 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { DisplayName string `json:"device_display_name,omitempty"` }{key.DisplayName}) + respMu.Lock() res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON + respMu.Unlock() } return nil } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index cfbb05327..f767615c8 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -428,6 +428,13 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( return } + // Only notify clients about retired invite events, if the user didn't accept the invite. + // The PDU stream will also receive an event about accepting the invitation, so there should + // be a "smooth" transition from invite -> join, and not invite -> leave -> join + if msg.Membership == gomatrixserverlib.Join { + return + } + // Notify any active sync requests that the invite has been retired. 
s.inviteStream.Advance(pduPos) s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID) diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index b562e6804..0ecbdf4d2 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -28,8 +28,9 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" "github.com/lib/pq" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) const outputRoomEventsSchema = ` @@ -133,7 +134,7 @@ const updateEventJSONSQL = "" + "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2" // In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). -const selectStateInRangeSQL = "" + +const selectStateInRangeFilteredSQL = "" + "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + @@ -146,6 +147,15 @@ const selectStateInRangeSQL = "" + " ORDER BY id ASC" + " LIMIT $9" +// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). +const selectStateInRangeSQL = "" + + "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" + + " FROM syncapi_output_room_events" + + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + + " AND room_id = ANY($3)" + + " ORDER BY id ASC" + + " LIMIT $4" + const deleteEventsForRoomSQL = "" + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" @@ -171,20 +181,21 @@ const selectContextAfterEventSQL = "" + const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type = ANY($2) ORDER BY id ASC LIMIT $3" type outputRoomEventsStatements struct { - insertEventStmt *sql.Stmt - selectEventsStmt *sql.Stmt - selectEventsWitFilterStmt *sql.Stmt - selectMaxEventIDStmt *sql.Stmt - selectRecentEventsStmt *sql.Stmt - selectRecentEventsForSyncStmt *sql.Stmt - selectEarlyEventsStmt *sql.Stmt - selectStateInRangeStmt *sql.Stmt - updateEventJSONStmt *sql.Stmt - deleteEventsForRoomStmt *sql.Stmt - selectContextEventStmt *sql.Stmt - selectContextBeforeEventStmt *sql.Stmt - selectContextAfterEventStmt *sql.Stmt - selectSearchStmt *sql.Stmt + insertEventStmt *sql.Stmt + selectEventsStmt *sql.Stmt + selectEventsWitFilterStmt *sql.Stmt + selectMaxEventIDStmt *sql.Stmt + selectRecentEventsStmt *sql.Stmt + selectRecentEventsForSyncStmt *sql.Stmt + selectEarlyEventsStmt *sql.Stmt + selectStateInRangeFilteredStmt *sql.Stmt + selectStateInRangeStmt *sql.Stmt + updateEventJSONStmt *sql.Stmt + deleteEventsForRoomStmt *sql.Stmt + selectContextEventStmt *sql.Stmt + selectContextBeforeEventStmt *sql.Stmt + selectContextAfterEventStmt *sql.Stmt + selectSearchStmt *sql.Stmt } func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { @@ -214,6 +225,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectRecentEventsStmt, selectRecentEventsSQL}, {&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL}, {&s.selectEarlyEventsStmt, selectEarlyEventsSQL}, + {&s.selectStateInRangeFilteredStmt, 
selectStateInRangeFilteredSQL}, {&s.selectStateInRangeStmt, selectStateInRangeSQL}, {&s.updateEventJSONStmt, updateEventJSONSQL}, {&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL}, @@ -240,17 +252,28 @@ func (s *outputRoomEventsStatements) SelectStateInRange( ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter, roomIDs []string, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { - stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) - senders, notSenders := getSendersStateFilterFilter(stateFilter) - rows, err := stmt.QueryContext( - ctx, r.Low(), r.High(), pq.StringArray(roomIDs), - pq.StringArray(senders), - pq.StringArray(notSenders), - pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), - pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), - stateFilter.ContainsURL, - stateFilter.Limit, - ) + var rows *sql.Rows + var err error + if stateFilter != nil { + stmt := sqlutil.TxStmt(txn, s.selectStateInRangeFilteredStmt) + senders, notSenders := getSendersStateFilterFilter(stateFilter) + rows, err = stmt.QueryContext( + ctx, r.Low(), r.High(), pq.StringArray(roomIDs), + pq.StringArray(senders), + pq.StringArray(notSenders), + pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), + pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), + stateFilter.ContainsURL, + stateFilter.Limit, + ) + } else { + stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) + rows, err = stmt.QueryContext( + ctx, r.Low(), r.High(), pq.StringArray(roomIDs), + r.High()-r.Low(), + ) + } + if err != nil { return nil, nil, err } diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index cb61c1c26..1f66ccc0e 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -5,10 +5,11 @@ import ( "database/sql" "fmt" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) type DatabaseTransaction struct { @@ -277,6 +278,7 @@ func (d *DatabaseTransaction) GetBackwardTopologyPos( // exclusive of oldPos, inclusive of newPos, for the rooms in which // the user has new membership events. // A list of joined room IDs is also returned in case the caller needs it. +// nolint:gocyclo func (d *DatabaseTransaction) GetStateDeltas( ctx context.Context, device *userapi.Device, r types.Range, userID string, @@ -311,7 +313,7 @@ func (d *DatabaseTransaction) GetStateDeltas( } // get all the state events ever (i.e. for all available rooms) between these two positions - stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) + stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, nil, allRoomIDs) if err != nil { if err == sql.ErrNoRows { return nil, nil, nil @@ -326,6 +328,22 @@ func (d *DatabaseTransaction) GetStateDeltas( return nil, nil, err } + // get all the state events ever (i.e. 
for all available rooms) between these two positions + stateNeededFiltered, eventMapFiltered, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + stateFiltered, err := d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + // find out which rooms this user is peeking, if any. // We do this before joins so any peeks get overwritten peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r) @@ -371,6 +389,7 @@ func (d *DatabaseTransaction) GetStateDeltas( // If our membership is now join but the previous membership wasn't // then this is a "join transition", so we'll insert this room. if prevMembership != membership { + newlyJoinedRooms[roomID] = true // Get the full room state, as we'll send that down for a newly // joined room instead of a delta. var s []types.StreamEvent @@ -383,8 +402,7 @@ func (d *DatabaseTransaction) GetStateDeltas( // Add the information for this room into the state so that // it will get added with all of the rest of the joined rooms. - state[roomID] = s - newlyJoinedRooms[roomID] = true + stateFiltered[roomID] = s } // We won't add joined rooms into the delta at this point as they @@ -395,7 +413,7 @@ func (d *DatabaseTransaction) GetStateDeltas( deltas = append(deltas, types.StateDelta{ Membership: membership, MembershipPos: ev.StreamPosition, - StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), + StateEvents: d.StreamEventsToEvents(device, stateFiltered[roomID]), RoomID: roomID, }) break @@ -407,7 +425,7 @@ func (d *DatabaseTransaction) GetStateDeltas( for _, joinedRoomID := range joinedRoomIDs { deltas = append(deltas, types.StateDelta{ Membership: gomatrixserverlib.Join, - StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), + StateEvents: d.StreamEventsToEvents(device, stateFiltered[joinedRoomID]), RoomID: joinedRoomID, NewlyJoined: newlyJoinedRooms[joinedRoomID], }) diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index d6a674b9c..77c692ff0 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -29,8 +29,9 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) const outputRoomEventsSchema = ` @@ -189,21 +190,36 @@ func (s *outputRoomEventsStatements) SelectStateInRange( for _, roomID := range roomIDs { inputParams = append(inputParams, roomID) } - stmt, params, err := prepareWithFilters( - s.db, txn, stmtSQL, inputParams, - stateFilter.Senders, stateFilter.NotSenders, - stateFilter.Types, stateFilter.NotTypes, - nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc, + var ( + stmt *sql.Stmt + params []any + err error ) + if stateFilter != nil { + stmt, params, err = prepareWithFilters( + s.db, txn, stmtSQL, inputParams, + stateFilter.Senders, stateFilter.NotSenders, + stateFilter.Types, stateFilter.NotTypes, + nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc, + ) + } else { + stmt, params, err = prepareWithFilters( + s.db, txn, stmtSQL, inputParams, + nil, nil, + nil, nil, + nil, nil, int(r.High()-r.Low()), 
FilterOrderAsc, + ) + } if err != nil { return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) } + defer internal.CloseAndLogIfError(ctx, stmt, "selectStateInRange: stmt.close() failed") rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, nil, err } - defer rows.Close() // nolint: errcheck + defer internal.CloseAndLogIfError(ctx, rows, "selectStateInRange: rows.close() failed") // Fetch all the state change events for all rooms between the two positions then loop each event and: // - Keep a cache of the event by ID (99% of state change events are for the event itself) // - For each room ID, build up an array of event IDs which represents cumulative adds/removes @@ -269,6 +285,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID( ) (id int64, err error) { var nullableID sql.NullInt64 stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) + defer internal.CloseAndLogIfError(ctx, stmt, "SelectMaxEventID: stmt.close() failed") err = stmt.QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 @@ -323,6 +340,7 @@ func (s *outputRoomEventsStatements) InsertEvent( return 0, err } insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + defer internal.CloseAndLogIfError(ctx, insertStmt, "InsertEvent: stmt.close() failed") _, err = insertStmt.ExecContext( ctx, streamPos, @@ -367,6 +385,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( if err != nil { return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) } + defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed") rows, err := stmt.QueryContext(ctx, params...) if err != nil { @@ -415,6 +434,8 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( if err != nil { return nil, fmt.Errorf("s.prepareWithFilters: %w", err) } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectEarlyEvents: stmt.close() failed") + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, err @@ -456,6 +477,8 @@ func (s *outputRoomEventsStatements) SelectEvents( if err != nil { return nil, err } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectEvents: stmt.close() failed") + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, err @@ -558,6 +581,10 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent( filter.Types, filter.NotTypes, nil, filter.ContainsURL, filter.Limit, FilterOrderDesc, ) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectContextBeforeEvent: stmt.close() failed") rows, err := stmt.QueryContext(ctx, params...) if err != nil { @@ -596,6 +623,10 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent( filter.Types, filter.NotTypes, nil, filter.ContainsURL, filter.Limit, FilterOrderAsc, ) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectContextAfterEvent: stmt.close() failed") rows, err := stmt.QueryContext(ctx, params...) 
if err != nil { diff --git a/syncapi/storage/tables/memberships_test.go b/syncapi/storage/tables/memberships_test.go new file mode 100644 index 000000000..0cee7f5a5 --- /dev/null +++ b/syncapi/storage/tables/memberships_test.go @@ -0,0 +1,198 @@ +package tables_test + +import ( + "context" + "database/sql" + "reflect" + "sort" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" +) + +func newMembershipsTable(t *testing.T, dbType test.DBType) (tables.Memberships, *sql.DB, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + var tab tables.Memberships + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresMembershipsTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSqliteMembershipsTable(db) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + return tab, db, close +} + +func TestMembershipsTable(t *testing.T) { + + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + // Create users + var userEvents []*gomatrixserverlib.HeaderedEvent + users := []string{alice.ID} + for _, x := range room.CurrentState() { + if x.StateKeyEquals(alice.ID) { + if _, err := x.Membership(); err == nil { + userEvents = append(userEvents, x) + break + } + } + } + + if len(userEvents) == 0 { + t.Fatalf("didn't find creator membership event") + } + + for i := 0; i < 10; i++ { + u := test.NewUser(t) + users = append(users, u.ID) + + ev := room.CreateAndInsert(t, u, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(u.ID)) + userEvents = append(userEvents, ev) + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + table, _, close := newMembershipsTable(t, dbType) + defer close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + for _, ev := range userEvents { + if err := table.UpsertMembership(ctx, nil, ev, types.StreamPosition(ev.Depth()), 1); err != nil { + t.Fatalf("failed to upsert membership: %s", err) + } + } + + testUpsert(t, ctx, table, userEvents[0], alice, room) + testMembershipCount(t, ctx, table, room) + testHeroes(t, ctx, table, alice, room, users) + }) +} + +func testHeroes(t *testing.T, ctx context.Context, table tables.Memberships, user *test.User, room *test.Room, users []string) { + + // Re-slice and sort the expected users + users = users[1:] + sort.Strings(users) + type testCase struct { + name string + memberships []string + wantHeroes []string + } + + testCases := []testCase{ + {name: "no memberships queried", memberships: []string{}}, + {name: "joined memberships queried should be limited", memberships: []string{gomatrixserverlib.Join}, wantHeroes: users[:5]}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := table.SelectHeroes(ctx, nil, room.ID, user.ID, tc.memberships) + if err != nil { + t.Fatalf("unable to select heroes: %s", err) + } + if gotLen := len(got); gotLen 
!= len(tc.wantHeroes) { + t.Fatalf("expected %d heroes, got %d", len(tc.wantHeroes), gotLen) + } + + if !reflect.DeepEqual(got, tc.wantHeroes) { + t.Fatalf("expected heroes to be %+v, got %+v", tc.wantHeroes, got) + } + }) + } +} + +func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) { + t.Run("membership counts are correct", func(t *testing.T) { + // After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users) + count, err := table.SelectMembershipCount(ctx, nil, room.ID, gomatrixserverlib.Join, 10) + if err != nil { + t.Fatalf("failed to get membership count: %s", err) + } + expectedCount := 6 + if expectedCount != count { + t.Fatalf("expected member count to be %d, got %d", expectedCount, count) + } + + // After 100 events, we should have all 11 users + count, err = table.SelectMembershipCount(ctx, nil, room.ID, gomatrixserverlib.Join, 100) + if err != nil { + t.Fatalf("failed to get membership count: %s", err) + } + expectedCount = 11 + if expectedCount != count { + t.Fatalf("expected member count to be %d, got %d", expectedCount, count) + } + }) +} + +func testUpsert(t *testing.T, ctx context.Context, table tables.Memberships, membershipEvent *gomatrixserverlib.HeaderedEvent, user *test.User, room *test.Room) { + t.Run("upserting works as expected", func(t *testing.T) { + if err := table.UpsertMembership(ctx, nil, membershipEvent, 1, 1); err != nil { + t.Fatalf("failed to upsert membership: %s", err) + } + membership, pos, err := table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 1) + if err != nil { + t.Fatalf("failed to select membership: %s", err) + } + expectedPos := 1 + if pos != expectedPos { + t.Fatalf("expected pos to be %d, got %d", expectedPos, pos) + } + if membership != gomatrixserverlib.Join { + t.Fatalf("expected membership to be join, got %s", membership) + } + // Create a new event which gets upserted and should not cause issues + ev := room.CreateAndInsert(t, user, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": gomatrixserverlib.Join, + }, test.WithStateKey(user.ID)) + // Insert the same event again, but with different positions, which should get updated + if err = table.UpsertMembership(ctx, nil, ev, 2, 2); err != nil { + t.Fatalf("failed to upsert membership: %s", err) + } + + // Verify the position got updated + membership, pos, err = table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 10) + if err != nil { + t.Fatalf("failed to select membership: %s", err) + } + expectedPos = 2 + if pos != expectedPos { + t.Fatalf("expected pos to be %d, got %d", expectedPos, pos) + } + if membership != gomatrixserverlib.Join { + t.Fatalf("expected membership to be join, got %s", membership) + } + + // If we can't find a membership, it should default to leave + if membership, _, err = table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 1); err != nil { + t.Fatalf("failed to select membership: %s", err) + } + if membership != gomatrixserverlib.Leave { + t.Fatalf("expected membership to be leave, got %s", membership) + } + }) +} diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 7875ffa35..700f25c10 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -74,21 +74,26 @@ func (p *InviteStreamProvider) IncrementalSync( return to } for roomID := range retiredInvites { - if _, ok := req.Response.Rooms.Join[roomID]; !ok { - lr := types.NewLeaveResponse() - h := 
sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...)) - lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{ - // fake event ID which muxes in the to position - EventID: "$" + base64.RawURLEncoding.EncodeToString(h[:]), - OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), - RoomID: roomID, - Sender: req.Device.UserID, - StateKey: &req.Device.UserID, - Type: "m.room.member", - Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`), - }) - req.Response.Rooms.Leave[roomID] = lr + if _, ok := req.Response.Rooms.Invite[roomID]; ok { + continue } + if _, ok := req.Response.Rooms.Join[roomID]; ok { + continue + } + lr := types.NewLeaveResponse() + h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...)) + lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{ + // fake event ID which muxes in the to position + EventID: "$" + base64.RawURLEncoding.EncodeToString(h[:]), + OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), + RoomID: roomID, + Sender: req.Device.UserID, + StateKey: &req.Device.UserID, + Type: "m.room.member", + Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`), + }) + req.Response.Rooms.Leave[roomID] = lr + } return maxID diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 613ac434f..9ec2b61cd 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -194,7 +194,7 @@ func (p *PDUStreamProvider) IncrementalSync( } } var pos types.StreamPosition - if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil { + if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req); err != nil { req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed") if err == context.DeadlineExceeded || err == context.Canceled || err == sql.ErrTxDone { return newPos @@ -225,7 +225,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( delta types.StateDelta, eventFilter *gomatrixserverlib.RoomEventFilter, stateFilter *gomatrixserverlib.StateFilter, - res *types.Response, + req *types.SyncRequest, ) (types.StreamPosition, error) { if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave { // make sure we don't leak recent events after the leave event. @@ -290,8 +290,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( hasMembershipChange := false for _, recentEvent := range recentStreamEvents { if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil { + if membership, _ := recentEvent.Membership(); membership == gomatrixserverlib.Join { + req.MembershipChanges[*recentEvent.StateKey()] = struct{}{} + } hasMembershipChange = true - break } } @@ -318,9 +320,9 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. 
- jr.Timeline.Limited = limited && len(events) == len(recentEvents) + jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) - res.Rooms.Join[delta.RoomID] = jr + req.Response.Rooms.Join[delta.RoomID] = jr case gomatrixserverlib.Peek: jr := types.NewJoinResponse() @@ -329,7 +331,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = limited jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) - res.Rooms.Peek[delta.RoomID] = jr + req.Response.Rooms.Peek[delta.RoomID] = jr case gomatrixserverlib.Leave: fallthrough // transitions to leave are the same as ban @@ -342,7 +344,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( // didn't "remove" events, return that the response is limited. lr.Timeline.Limited = limited && len(events) == len(recentEvents) lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) - res.Rooms.Leave[delta.RoomID] = lr + req.Response.Rooms.Leave[delta.RoomID] = lr } return latestPosition, nil diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index 52b68a710..6a97e25bc 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -92,15 +92,16 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat }) return &types.SyncRequest{ - Context: req.Context(), // - Log: logger, // - Device: &device, // - Response: types.NewResponse(), // Populated by all streams - Filter: filter, // - Since: since, // - Timeout: timeout, // - Rooms: make(map[string]string), // Populated by the PDU stream - WantFullState: wantFullState, // + Context: req.Context(), // + Log: logger, // + Device: &device, // + Response: types.NewResponse(), // Populated by all streams + Filter: filter, // + Since: since, // + Timeout: timeout, // + Rooms: make(map[string]string), // Populated by the PDU stream + WantFullState: wantFullState, // + MembershipChanges: make(map[string]struct{}), // Populated by the PDU stream }, nil } diff --git a/syncapi/types/provider.go b/syncapi/types/provider.go index 378cafe99..9a533002b 100644 --- a/syncapi/types/provider.go +++ b/syncapi/types/provider.go @@ -4,9 +4,10 @@ import ( "context" "time" - userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" + + userapi "github.com/matrix-org/dendrite/userapi/api" ) type SyncRequest struct { @@ -22,6 +23,8 @@ type SyncRequest struct { // Updated by the PDU stream. Rooms map[string]string // Updated by the PDU stream. + MembershipChanges map[string]struct{} + // Updated by the PDU stream. 
IgnoredUsers IgnoredUsers } diff --git a/sytest-blacklist b/sytest-blacklist index 2641b1d78..4726be9c9 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -22,10 +22,6 @@ Forgotten room messages cannot be paginated Local device key changes get to remote servers with correct prev_id -# Flakey - -Local device key changes appear in /keys/changes - # we don't support groups Remove group category @@ -39,12 +35,6 @@ Events in rooms with AS-hosted room aliases are sent to AS server Inviting an AS-hosted user asks the AS server Accesing an AS-hosted room alias asks the AS server -# Flakey, need additional investigation - -Messages that notify from another user increment notification_count -Messages that highlight from another user increment unread highlight count -Notifications can be viewed with GET /notifications - # More flakey Guest users can join guest_access rooms diff --git a/sytest-whitelist b/sytest-whitelist index 27dd5688a..e9f46d327 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -746,3 +746,9 @@ Inbound federation can return missing events for joined visibility outliers whose auth_events are in a different room are correctly rejected Messages that notify from another user increment notification_count Messages that highlight from another user increment unread highlight count +Newly joined room has correct timeline in incremental sync +When user joins a room the state is included in the next sync +When user joins a room the state is included in a gapped sync +Messages that notify from another user increment notification_count +Messages that highlight from another user increment unread highlight count +Notifications can be viewed with GET /notifications diff --git a/userapi/api/api.go b/userapi/api/api.go index f1e30aeda..89b2831ee 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -96,7 +96,7 @@ type ClientUserAPI interface { PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error - SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error + SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error @@ -579,7 +579,10 @@ type Notification struct { type PerformSetAvatarURLRequest struct { Localpart, AvatarURL string } -type PerformSetAvatarURLResponse struct{} +type PerformSetAvatarURLResponse struct { + Profile *authtypes.Profile `json:"profile"` + Changed bool `json:"changed"` +} type QueryNumericLocalpartResponse struct { ID int64 @@ -606,6 +609,11 @@ type PerformUpdateDisplayNameRequest struct { Localpart, DisplayName string } +type PerformUpdateDisplayNameResponse struct { + Profile *authtypes.Profile `json:"profile"` + Changed bool `json:"changed"` +} + type QueryLocalpartForThreePIDRequest struct { ThreePID, Medium string } diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index 7e2f69615..90834f7e3 100644 --- 
a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -168,7 +168,7 @@ func (t *UserInternalAPITrace) QueryAccountAvailability(ctx context.Context, req return err } -func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error { +func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error { err := t.Impl.SetDisplayName(ctx, req, res) util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res)) return err diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go index c220d35cb..79f1bf06f 100644 --- a/userapi/consumers/clientapi.go +++ b/userapi/consumers/clientapi.go @@ -81,7 +81,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats readPos := msg.Header.Get(jetstream.EventID) evType := msg.Header.Get("type") - if readPos == "" || evType != "m.read" { + if readPos == "" || (evType != "m.read" && evType != "m.read.private") { return true } diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 3bbeb439a..e4587670f 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -10,19 +10,24 @@ import ( "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/storage" ) func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + base, baseclose := testrig.CreateBaseDendrite(t, dbType) t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ + db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "", 4, 0, 0, "") if err != nil { t.Fatalf("failed to create new user db: %v", err) } - return db, close + return db, func() { + close() + baseclose() + } } func mustCreateEvent(t *testing.T, content string) *gomatrixserverlib.HeaderedEvent { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 2f7795dfe..63044eedb 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -170,7 +170,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return nil } - if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { + if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { return err } @@ -813,7 +813,10 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush } func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { - return a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) + profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) + res.Profile = profile + res.Changed = changed + return err } func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { @@ -847,8 +850,11 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q } } -func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, _ *struct{}) error { - return a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) +func (a *UserInternalAPI) 
SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error { + profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) + res.Profile = profile + res.Changed = changed + return err } func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index a375d6caa..aa5d46d9f 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -388,7 +388,7 @@ func (h *httpUserInternalAPI) QueryAccountByPassword( func (h *httpUserInternalAPI) SetDisplayName( ctx context.Context, request *api.PerformUpdateDisplayNameRequest, - response *struct{}, + response *api.PerformUpdateDisplayNameResponse, ) error { return httputil.CallInternalRPCAPI( "SetDisplayName", h.apiURL+PerformSetDisplayNamePath, diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 02efe7afe..fb12b53af 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -29,8 +29,8 @@ import ( type Profile interface { GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) - SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error - SetDisplayName(ctx context.Context, localpart string, displayName string) error + SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error) } type Account interface { diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 67113367b..0b6a3af6d 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -26,7 +26,7 @@ import ( const accountDataSchema = ` -- Stores data about accounts data. -CREATE TABLE IF NOT EXISTS account_data ( +CREATE TABLE IF NOT EXISTS userapi_account_datas ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL, -- The room ID for this data (empty string if not specific to a room) @@ -41,15 +41,15 @@ CREATE TABLE IF NOT EXISTS account_data ( ` const insertAccountDataSQL = ` - INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) + INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content ` const selectAccountDataSQL = "" + - "SELECT room_id, type, content FROM account_data WHERE localpart = $1" + "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" const selectAccountDataByTypeSQL = "" + - "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" + "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" type accountDataStatements struct { insertAccountDataStmt *sql.Stmt diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index afd1ad410..d3d0bbcd5 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -32,7 +32,7 @@ import ( const accountsSchema = ` -- Stores data about accounts. 
-CREATE TABLE IF NOT EXISTS account_accounts ( +CREATE TABLE IF NOT EXISTS userapi_accounts ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL PRIMARY KEY, -- When this account was first created, as a unix timestamp (ms resolution). @@ -51,22 +51,22 @@ CREATE TABLE IF NOT EXISTS account_accounts ( ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" const updatePasswordSQL = "" + - "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" const deactivateAccountSQL = "" + - "UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1" + "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + - "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(localpart::bigint), 0) FROM account_accounts WHERE localpart ~ '^[0-9]{1,}$'" + "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'" type accountsStatements struct { insertAccountStmt *sql.Stmt diff --git a/userapi/storage/postgres/deltas/20200929203058_is_active.go b/userapi/storage/postgres/deltas/20200929203058_is_active.go index 24f87e073..2c5cc2f58 100644 --- a/userapi/storage/postgres/deltas/20200929203058_is_active.go +++ b/userapi/storage/postgres/deltas/20200929203058_is_active.go @@ -7,7 +7,7 @@ import ( ) func UpIsActive(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;") + _, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;") if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } @@ -15,7 +15,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error { } func DownIsActive(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN is_deactivated;") + _, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts DROP COLUMN is_deactivated;") if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go index edd3353f0..40e237027 100644 --- a/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go @@ -8,9 +8,9 @@ import ( func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` -ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000; -ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS ip TEXT; -ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`) +ALTER 
TABLE userapi_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000; +ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS ip TEXT; +ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`) if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } @@ -19,9 +19,9 @@ ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`) func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` - ALTER TABLE device_devices DROP COLUMN last_seen_ts; - ALTER TABLE device_devices DROP COLUMN ip; - ALTER TABLE device_devices DROP COLUMN user_agent;`) + ALTER TABLE userapi_devices DROP COLUMN last_seen_ts; + ALTER TABLE userapi_devices DROP COLUMN ip; + ALTER TABLE userapi_devices DROP COLUMN user_agent;`) if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go index eb7c3a958..164847e51 100644 --- a/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go +++ b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go @@ -9,10 +9,10 @@ import ( func UpAddAccountType(ctx context.Context, tx *sql.Tx) error { // initially set every account to useraccount, change appservice and guest accounts afterwards // (user = 1, guest = 2, admin = 3, appservice = 4) - _, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1; -UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; -UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$'; -ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`, + _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1; +UPDATE userapi_accounts SET account_type = 4 WHERE appservice_id <> ''; +UPDATE userapi_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$'; +ALTER TABLE userapi_accounts ALTER COLUMN account_type DROP DEFAULT;`, ) if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) @@ -21,7 +21,7 @@ ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`, } func DownAddAccountType(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN account_type;") + _, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts DROP COLUMN account_type;") if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/postgres/deltas/2022080800000000_no_guests.go b/userapi/storage/postgres/deltas/2022080800000000_no_guests.go index cc6126aad..9985fd822 100644 --- a/userapi/storage/postgres/deltas/2022080800000000_no_guests.go +++ b/userapi/storage/postgres/deltas/2022080800000000_no_guests.go @@ -8,7 +8,7 @@ import ( func UpNoGuests(ctx context.Context, tx *sql.Tx) error { // AddAccountType introduced a bug where each user that had was registered as a regular user, but without user_id, became a guest. 
- _, err := tx.ExecContext(ctx, "UPDATE account_accounts SET account_type = 1 WHERE account_type = 2;") + _, err := tx.ExecContext(ctx, "UPDATE userapi_accounts SET account_type = 1 WHERE account_type = 2;") if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } diff --git a/userapi/storage/postgres/deltas/2022101711000000_rename_tables.go b/userapi/storage/postgres/deltas/2022101711000000_rename_tables.go new file mode 100644 index 000000000..1d73d0af4 --- /dev/null +++ b/userapi/storage/postgres/deltas/2022101711000000_rename_tables.go @@ -0,0 +1,102 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" +) + +var renameTableMappings = map[string]string{ + "account_accounts": "userapi_accounts", + "account_data": "userapi_account_datas", + "device_devices": "userapi_devices", + "account_e2e_room_keys": "userapi_key_backups", + "account_e2e_room_keys_versions": "userapi_key_backup_versions", + "login_tokens": "userapi_login_tokens", + "open_id_tokens": "userapi_openid_tokens", + "account_profiles": "userapi_profiles", + "account_threepid": "userapi_threepids", +} + +var renameSequenceMappings = map[string]string{ + "device_session_id_seq": "userapi_device_session_id_seq", + "account_e2e_room_keys_versions_seq": "userapi_key_backup_versions_seq", +} + +var renameIndicesMappings = map[string]string{ + "device_localpart_id_idx": "userapi_device_localpart_id_idx", + "e2e_room_keys_idx": "userapi_key_backups_idx", + "e2e_room_keys_versions_idx": "userapi_key_backups_versions_idx", + "account_e2e_room_keys_versions_idx": "userapi_key_backup_versions_idx", + "login_tokens_expiration_idx": "userapi_login_tokens_expiration_idx", + "account_threepid_localpart": "userapi_threepid_idx", +} + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. 
+ +func UpRenameTables(ctx context.Context, tx *sql.Tx) error { + for old, new := range renameTableMappings { + q := fmt.Sprintf( + "ALTER TABLE IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(old), pq.QuoteIdentifier(new), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", old, new, err) + } + } + for old, new := range renameSequenceMappings { + q := fmt.Sprintf( + "ALTER SEQUENCE IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(old), pq.QuoteIdentifier(new), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", old, new, err) + } + } + for old, new := range renameIndicesMappings { + q := fmt.Sprintf( + "ALTER INDEX IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(old), pq.QuoteIdentifier(new), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", old, new, err) + } + } + return nil +} + +func DownRenameTables(ctx context.Context, tx *sql.Tx) error { + for old, new := range renameTableMappings { + q := fmt.Sprintf( + "ALTER TABLE IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(new), pq.QuoteIdentifier(old), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", new, old, err) + } + } + for old, new := range renameSequenceMappings { + q := fmt.Sprintf( + "ALTER SEQUENCE IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(new), pq.QuoteIdentifier(old), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", new, old, err) + } + } + for old, new := range renameIndicesMappings { + q := fmt.Sprintf( + "ALTER INDEX IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(new), pq.QuoteIdentifier(old), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", new, old, err) + } + } + return nil +} diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index f65681aae..8b7fbd6cf 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -31,10 +31,10 @@ import ( const devicesSchema = ` -- This sequence is used for automatic allocation of session_id. -CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; +CREATE SEQUENCE IF NOT EXISTS userapi_device_session_id_seq START 1; -- Stores data about devices. -CREATE TABLE IF NOT EXISTS device_devices ( +CREATE TABLE IF NOT EXISTS userapi_devices ( -- The access token granted to this device. This has to be the primary key -- so we can distinguish which device is making a given request. access_token TEXT NOT NULL PRIMARY KEY, @@ -42,7 +42,7 @@ CREATE TABLE IF NOT EXISTS device_devices ( -- This can be used as a secure substitution of the access token in situations -- where data is associated with access tokens (e.g. transaction storage), -- so we don't have to store users' access tokens everywhere. - session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'), + session_id BIGINT NOT NULL DEFAULT nextval('userapi_device_session_id_seq'), -- The device identifier. This only needs to uniquely identify a device for a given user, not globally. -- access_tokens will be clobbered based on the device ID for a user. device_id TEXT NOT NULL, @@ -65,39 +65,39 @@ CREATE TABLE IF NOT EXISTS device_devices ( ); -- Device IDs must be unique for a given user. 
-CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(localpart, device_id); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id); ` const insertDeviceSQL = "" + - "INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + + "INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + " RETURNING session_id" const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + - "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" const deleteDeviceSQL = "" + - "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" + "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" const deleteDevicesSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" + "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" type devicesStatements struct { insertDeviceStmt *sql.Stmt diff --git a/userapi/storage/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go index ac0e80617..7b58f7bae 100644 --- a/userapi/storage/postgres/key_backup_table.go +++ b/userapi/storage/postgres/key_backup_table.go @@ -26,7 +26,7 @@ import ( ) const keyBackupTableSchema = ` -CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( +CREATE TABLE IF NOT EXISTS userapi_key_backups ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, session_id TEXT NOT NULL, @@ -37,31 +37,31 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( is_verified BOOLEAN NOT NULL, session_data TEXT NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); -CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON 
account_e2e_room_keys(user_id, version); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backups_idx ON userapi_key_backups(user_id, room_id, session_id, version); +CREATE INDEX IF NOT EXISTS userapi_key_backups_versions_idx ON userapi_key_backups(user_id, version); ` const insertBackupKeySQL = "" + - "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + + "INSERT INTO userapi_key_backups(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" const updateBackupKeySQL = "" + - "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + + "UPDATE userapi_key_backups SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" const countKeysSQL = "" + - "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" const selectKeysSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" const selectKeysByRoomIDSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2 AND room_id = $3" const selectKeysByRoomIDAndSessionIDSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" type keyBackupStatements struct { diff --git a/userapi/storage/postgres/key_backup_version_table.go b/userapi/storage/postgres/key_backup_version_table.go index e78e4cd51..67c5e5481 100644 --- a/userapi/storage/postgres/key_backup_version_table.go +++ b/userapi/storage/postgres/key_backup_version_table.go @@ -26,40 +26,40 @@ import ( ) const keyBackupVersionTableSchema = ` -CREATE SEQUENCE IF NOT EXISTS account_e2e_room_keys_versions_seq; +CREATE SEQUENCE IF NOT EXISTS userapi_key_backup_versions_seq; -- the metadata for each generation of encrypted e2e session backups -CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( +CREATE TABLE IF NOT EXISTS userapi_key_backup_versions ( user_id TEXT NOT NULL, -- this means no 2 users will ever have the same version of e2e session backups which strictly -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. 
- version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'), + version BIGINT DEFAULT nextval('userapi_key_backup_versions_seq'), algorithm TEXT NOT NULL, auth_data TEXT NOT NULL, etag TEXT NOT NULL, deleted SMALLINT DEFAULT 0 NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backup_versions_idx ON userapi_key_backup_versions(user_id, version); ` const insertKeyBackupSQL = "" + - "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" + "INSERT INTO userapi_key_backup_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" const updateKeyBackupAuthDataSQL = "" + - "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + "UPDATE userapi_key_backup_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" const updateKeyBackupETagSQL = "" + - "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" + "UPDATE userapi_key_backup_versions SET etag = $1 WHERE user_id = $2 AND version = $3" const deleteKeyBackupSQL = "" + - "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + "UPDATE userapi_key_backup_versions SET deleted=1 WHERE user_id = $1 AND version = $2" const selectKeyBackupSQL = "" + - "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + "SELECT algorithm, auth_data, etag, deleted FROM userapi_key_backup_versions WHERE user_id = $1 AND version = $2" const selectLatestVersionSQL = "" + - "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" + "SELECT MAX(version) FROM userapi_key_backup_versions WHERE user_id = $1" type keyBackupVersionStatements struct { insertKeyBackupStmt *sql.Stmt diff --git a/userapi/storage/postgres/logintoken_table.go b/userapi/storage/postgres/logintoken_table.go index 4de96f839..44c6ca4ae 100644 --- a/userapi/storage/postgres/logintoken_table.go +++ b/userapi/storage/postgres/logintoken_table.go @@ -26,7 +26,7 @@ import ( ) const loginTokenSchema = ` -CREATE TABLE IF NOT EXISTS login_tokens ( +CREATE TABLE IF NOT EXISTS userapi_login_tokens ( -- The random value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, -- When the token expires @@ -37,17 +37,17 @@ CREATE TABLE IF NOT EXISTS login_tokens ( ); -- This index allows efficient garbage collection of expired tokens. 
-CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +CREATE INDEX IF NOT EXISTS userapi_login_tokens_expiration_idx ON userapi_login_tokens(token_expires_at); ` const insertLoginTokenSQL = "" + - "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" + "INSERT INTO userapi_login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" const deleteLoginTokenSQL = "" + - "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" + "DELETE FROM userapi_login_tokens WHERE token = $1 OR token_expires_at <= $2" const selectLoginTokenSQL = "" + - "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" + "SELECT user_id FROM userapi_login_tokens WHERE token = $1 AND token_expires_at > $2" type loginTokenStatements struct { insertStmt *sql.Stmt @@ -78,7 +78,7 @@ func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx // deleteByToken removes the named token. // // As a simple way to garbage-collect stale tokens, we also remove all expired tokens. -// The login_tokens_expiration_idx index should make that efficient. +// The userapi_login_tokens_expiration_idx index should make that efficient. func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { stmt := sqlutil.TxStmt(txn, s.deleteStmt) res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) diff --git a/userapi/storage/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go index 29c3ddcb4..06ae30d08 100644 --- a/userapi/storage/postgres/openid_table.go +++ b/userapi/storage/postgres/openid_table.go @@ -13,7 +13,7 @@ import ( const openIDTokenSchema = ` -- Stores data about openid tokens issued for accounts. -CREATE TABLE IF NOT EXISTS open_id_tokens ( +CREATE TABLE IF NOT EXISTS userapi_openid_tokens ( -- The value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, -- The Matrix user ID for this account @@ -24,10 +24,10 @@ CREATE TABLE IF NOT EXISTS open_id_tokens ( ` const insertOpenIDTokenSQL = "" + - "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" const selectOpenIDTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" + "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" type openIDTokenStatements struct { insertTokenStmt *sql.Stmt diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index 6d336eb8e..2753b23d9 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -27,7 +27,7 @@ import ( const profilesSchema = ` -- Stores data about accounts profiles. 
-CREATE TABLE IF NOT EXISTS account_profiles ( +CREATE TABLE IF NOT EXISTS userapi_profiles ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL PRIMARY KEY, -- The display name for this account @@ -38,19 +38,27 @@ CREATE TABLE IF NOT EXISTS account_profiles ( ` const insertProfileSQL = "" + - "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" const setAvatarURLSQL = "" + - "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" + "UPDATE userapi_profiles AS new" + + " SET avatar_url = $1" + + " FROM userapi_profiles AS old" + + " WHERE new.localpart = $2" + + " RETURNING new.display_name, old.avatar_url <> new.avatar_url" const setDisplayNameSQL = "" + - "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" + "UPDATE userapi_profiles AS new" + + " SET display_name = $1" + + " FROM userapi_profiles AS old" + + " WHERE new.localpart = $2" + + " RETURNING new.avatar_url, old.display_name <> new.display_name" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { serverNoticesLocalpart string @@ -100,16 +108,28 @@ func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, -) (err error) { - _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) - return +) (*authtypes.Profile, bool, error) { + profile := &authtypes.Profile{ + Localpart: localpart, + AvatarURL: avatarURL, + } + var changed bool + stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) + err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed) + return profile, changed, err } func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, localpart string, displayName string, -) (err error) { - _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) - return +) (*authtypes.Profile, bool, error) { + profile := &authtypes.Profile{ + Localpart: localpart, + DisplayName: displayName, + } + var changed bool + stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) + err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed) + return profile, changed, err } func (s *profilesStatements) SelectProfilesBySearch( diff --git a/userapi/storage/postgres/stats_table.go b/userapi/storage/postgres/stats_table.go index c0b317503..20eb0bf46 100644 --- a/userapi/storage/postgres/stats_table.go +++ b/userapi/storage/postgres/stats_table.go @@ -45,7 +45,7 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON usera const countUsersLastSeenAfterSQL = "" + "SELECT COUNT(*) FROM (" + - " SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " + + " SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " + " GROUP BY localpart" + " ) u" @@ -62,7 +62,7 @@ R30Users counts the number of 30 day retained users, defined as: const countR30UsersSQL = ` SELECT platform, 
COUNT(*) FROM ( SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts) - FROM account_accounts users + FROM userapi_accounts users INNER JOIN (SELECT localpart, last_seen_ts, @@ -75,7 +75,7 @@ SELECT platform, COUNT(*) FROM ( ELSE 'unknown' END AS platform - FROM device_devices + FROM userapi_devices ) uip ON users.localpart = uip.localpart AND users.account_type <> 4 @@ -121,7 +121,7 @@ GROUP BY client_type ` const countUserByAccountTypeSQL = ` -SELECT COUNT(*) FROM account_accounts WHERE account_type = ANY($1) +SELECT COUNT(*) FROM userapi_accounts WHERE account_type = ANY($1) ` // $1 = All non guest AccountType IDs @@ -134,7 +134,7 @@ SELECT user_type, COUNT(*) AS count FROM ( WHEN account_type = $2 AND appservice_id IS NULL THEN 'guest' WHEN account_type = ANY($1) AND appservice_id IS NOT NULL THEN 'bridged' END AS user_type - FROM account_accounts + FROM userapi_accounts WHERE created_ts > $3 ) AS t GROUP BY user_type ` @@ -143,14 +143,14 @@ SELECT user_type, COUNT(*) AS count FROM ( const updateUserDailyVisitsSQL = ` INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent) SELECT u.localpart, u.device_id, $1, MAX(u.user_agent) - FROM device_devices AS u + FROM userapi_devices AS u LEFT JOIN ( SELECT localpart, device_id, timestamp FROM userapi_daily_visits WHERE timestamp = $1 ) udv ON u.localpart = udv.localpart AND u.device_id = udv.device_id - INNER JOIN device_devices d ON d.localpart = u.localpart - INNER JOIN account_accounts a ON a.localpart = u.localpart + INNER JOIN userapi_devices d ON d.localpart = u.localpart + INNER JOIN userapi_accounts a ON a.localpart = u.localpart WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3 AND a.account_type in (1, 3) GROUP BY u.localpart, u.device_id diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 7d3b9b6a5..c059e3e60 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/userapi/storage/shared" // Import the postgres database driver. 
@@ -36,6 +37,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return nil, err } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: rename tables", + Up: deltas.UpRenameTables, + Down: deltas.DownRenameTables, + }) + if err = m.Up(base.Context()); err != nil { + return nil, err + } + accountDataTable, err := NewPostgresAccountDataTable(db) if err != nil { return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) diff --git a/userapi/storage/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go index 63c08d61f..11af76161 100644 --- a/userapi/storage/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -26,7 +26,7 @@ import ( const threepidSchema = ` -- Stores data about third party identifiers -CREATE TABLE IF NOT EXISTS account_threepid ( +CREATE TABLE IF NOT EXISTS userapi_threepids ( -- The third party identifier threepid TEXT NOT NULL, -- The 3PID medium @@ -37,20 +37,20 @@ CREATE TABLE IF NOT EXISTS account_threepid ( PRIMARY KEY(threepid, medium) ); -CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart); +CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart); ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" + "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + - "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" + "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" const insertThreePIDSQL = "" + - "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" + "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" const deleteThreePIDSQL = "" + - "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" + "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" type threepidStatements struct { selectLocalpartForThreePIDStmt *sql.Stmt diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 294f96918..912c6639a 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -96,20 +96,24 @@ func (d *Database) GetProfileByLocalpart( // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetAvatarURL( ctx context.Context, localpart string, avatarURL string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) +) (profile *authtypes.Profile, changed bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) + return err }) + return } // SetDisplayName updates the display name of the profile associated with the given // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetDisplayName( ctx context.Context, localpart string, displayName string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) +) (profile *authtypes.Profile, changed bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) + return err }) + return } // SetPassword sets the account password to the given hash. 
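
The new Postgres `setAvatarURLSQL` and `setDisplayNameSQL` statements above replace plain UPDATEs with a self-join plus `RETURNING`, so a single round trip both writes the new value and reports whether it differed from what was already stored — the `changed` flag that `SetAvatarURL`/`SetDisplayName` now return alongside the profile. Below is a minimal, self-contained sketch of that pattern; it is not Dendrite code — the connection string is a placeholder and the explicit `old.localpart = new.localpart` join condition is written out for clarity — but the table and column names match the statements in this diff.

```go
// Sketch of the "update and report whether anything changed" pattern used by
// the new Postgres profile statements. Not part of this changeset.
package main

import (
	"context"
	"database/sql"
	"fmt"
	"log"

	_ "github.com/lib/pq"
)

// setAvatarURL updates avatar_url for a localpart and reports whether the
// stored value actually changed, in a single statement. The FROM clause joins
// the table to itself so "old" carries the pre-update row, while "new" in the
// RETURNING clause carries the post-update values.
func setAvatarURL(ctx context.Context, db *sql.DB, localpart, avatarURL string) (displayName string, changed bool, err error) {
	const q = `
UPDATE userapi_profiles AS new
   SET avatar_url = $1
  FROM userapi_profiles AS old
 WHERE new.localpart = $2
   AND old.localpart = new.localpart
RETURNING new.display_name, old.avatar_url <> new.avatar_url`
	err = db.QueryRowContext(ctx, q, avatarURL, localpart).Scan(&displayName, &changed)
	return
}

func main() {
	// Placeholder connection string for the example only.
	db, err := sql.Open("postgres", "postgres://dendrite:dendrite@localhost/dendrite?sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	name, changed, err := setAvatarURL(context.Background(), db, "alice", "mxc://example.org/newAvatar")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("display_name=%q changed=%v\n", name, changed)
}
```

If the value written is identical to the stored one, the statement still returns a row but the comparison evaluates to false, which is what lets the callers skip rebuilding membership events when nothing changed.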
diff --git a/userapi/storage/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go index cfd8568a9..af12decb3 100644 --- a/userapi/storage/sqlite3/account_data_table.go +++ b/userapi/storage/sqlite3/account_data_table.go @@ -25,7 +25,7 @@ import ( const accountDataSchema = ` -- Stores data about accounts data. -CREATE TABLE IF NOT EXISTS account_data ( +CREATE TABLE IF NOT EXISTS userapi_account_datas ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL, -- The room ID for this data (empty string if not specific to a room) @@ -40,15 +40,15 @@ CREATE TABLE IF NOT EXISTS account_data ( ` const insertAccountDataSQL = ` - INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) + INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 ` const selectAccountDataSQL = "" + - "SELECT room_id, type, content FROM account_data WHERE localpart = $1" + "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" const selectAccountDataByTypeSQL = "" + - "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" + "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" type accountDataStatements struct { db *sql.DB diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 484e90056..671c1aa04 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -32,7 +32,7 @@ import ( const accountsSchema = ` -- Stores data about accounts. -CREATE TABLE IF NOT EXISTS account_accounts ( +CREATE TABLE IF NOT EXISTS userapi_accounts ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL PRIMARY KEY, -- When this account was first created, as a unix timestamp (ms resolution). 
@@ -51,22 +51,22 @@ CREATE TABLE IF NOT EXISTS account_accounts ( ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" const updatePasswordSQL = "" + - "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" const deactivateAccountSQL = "" + - "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" + "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + - "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0" + "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0" type accountsStatements struct { db *sql.DB diff --git a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go index e25efc695..9158cb365 100644 --- a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go +++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go @@ -8,8 +8,8 @@ import ( func UpIsActive(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` - ALTER TABLE account_accounts RENAME TO account_accounts_tmp; -CREATE TABLE account_accounts ( + ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; +CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, created_ts BIGINT NOT NULL, password_hash TEXT, @@ -17,13 +17,13 @@ CREATE TABLE account_accounts ( is_deactivated BOOLEAN DEFAULT 0 ); INSERT - INTO account_accounts ( + INTO userapi_accounts ( localpart, created_ts, password_hash, appservice_id ) SELECT localpart, created_ts, password_hash, appservice_id - FROM account_accounts_tmp + FROM userapi_accounts_tmp ; -DROP TABLE account_accounts_tmp;`) +DROP TABLE userapi_accounts_tmp;`) if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } @@ -32,21 +32,21 @@ DROP TABLE account_accounts_tmp;`) func DownIsActive(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` - ALTER TABLE account_accounts RENAME TO account_accounts_tmp; -CREATE TABLE account_accounts ( + ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; +CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, created_ts BIGINT NOT NULL, password_hash TEXT, appservice_id TEXT ); INSERT - INTO account_accounts ( + INTO userapi_accounts ( localpart, created_ts, password_hash, appservice_id ) SELECT localpart, created_ts, password_hash, appservice_id - FROM account_accounts_tmp + FROM userapi_accounts_tmp ; -DROP TABLE account_accounts_tmp;`) +DROP TABLE userapi_accounts_tmp;`) if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go 
b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go index 7f7e95d2d..a9224db6b 100644 --- a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go @@ -8,8 +8,8 @@ import ( func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` - ALTER TABLE device_devices RENAME TO device_devices_tmp; - CREATE TABLE device_devices ( + ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp; + CREATE TABLE userapi_devices ( access_token TEXT PRIMARY KEY, session_id INTEGER, device_id TEXT , @@ -22,12 +22,12 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { UNIQUE (localpart, device_id) ); INSERT - INTO device_devices ( + INTO userapi_devices ( access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent ) SELECT access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', '' - FROM device_devices_tmp; - DROP TABLE device_devices_tmp;`) + FROM userapi_devices_tmp; + DROP TABLE userapi_devices_tmp;`) if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } @@ -36,8 +36,8 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` -ALTER TABLE device_devices RENAME TO device_devices_tmp; -CREATE TABLE IF NOT EXISTS device_devices ( +ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp; +CREATE TABLE IF NOT EXISTS userapi_devices ( access_token TEXT PRIMARY KEY, session_id INTEGER, device_id TEXT , @@ -47,12 +47,12 @@ CREATE TABLE IF NOT EXISTS device_devices ( UNIQUE (localpart, device_id) ); INSERT -INTO device_devices ( +INTO userapi_devices ( access_token, session_id, device_id, localpart, created_ts, display_name ) SELECT access_token, session_id, device_id, localpart, created_ts, display_name -FROM device_devices_tmp; -DROP TABLE device_devices_tmp;`) +FROM userapi_devices_tmp; +DROP TABLE userapi_devices_tmp;`) if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go index 46532698c..230bc1433 100644 --- a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go +++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go @@ -9,8 +9,8 @@ import ( func UpAddAccountType(ctx context.Context, tx *sql.Tx) error { // initially set every account to useraccount, change appservice and guest accounts afterwards // (user = 1, guest = 2, admin = 3, appservice = 4) - _, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts RENAME TO account_accounts_tmp; -CREATE TABLE account_accounts ( + _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; +CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, created_ts BIGINT NOT NULL, password_hash TEXT, @@ -19,15 +19,15 @@ CREATE TABLE account_accounts ( account_type INTEGER NOT NULL ); INSERT - INTO account_accounts ( + INTO userapi_accounts ( localpart, created_ts, password_hash, appservice_id, account_type ) SELECT localpart, created_ts, password_hash, appservice_id, 1 - FROM account_accounts_tmp + FROM userapi_accounts_tmp ; -UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; -UPDATE account_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*'; -DROP TABLE 
account_accounts_tmp;`) +UPDATE userapi_accounts SET account_type = 4 WHERE appservice_id <> ''; +UPDATE userapi_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*'; +DROP TABLE userapi_accounts_tmp;`) if err != nil { return fmt.Errorf("failed to add column: %w", err) } @@ -35,7 +35,7 @@ DROP TABLE account_accounts_tmp;`) } func DownAddAccountType(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts DROP COLUMN account_type;`) + _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts DROP COLUMN account_type;`) if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/sqlite3/deltas/2022101711000000_rename_tables.go b/userapi/storage/sqlite3/deltas/2022101711000000_rename_tables.go new file mode 100644 index 000000000..4ca1dc475 --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022101711000000_rename_tables.go @@ -0,0 +1,109 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + "strings" +) + +var renameTableMappings = map[string]string{ + "account_accounts": "userapi_accounts", + "account_data": "userapi_account_datas", + "device_devices": "userapi_devices", + "account_e2e_room_keys": "userapi_key_backups", + "account_e2e_room_keys_versions": "userapi_key_backup_versions", + "login_tokens": "userapi_login_tokens", + "open_id_tokens": "userapi_openid_tokens", + "account_profiles": "userapi_profiles", + "account_threepid": "userapi_threepids", +} + +var renameIndicesMappings = map[string]string{ + "device_localpart_id_idx": "userapi_device_localpart_id_idx", + "e2e_room_keys_idx": "userapi_key_backups_idx", + "e2e_room_keys_versions_idx": "userapi_key_backups_versions_idx", + "account_e2e_room_keys_versions_idx": "userapi_key_backup_versions_idx", + "login_tokens_expiration_idx": "userapi_login_tokens_expiration_idx", + "account_threepid_localpart": "userapi_threepid_idx", +} + +func UpRenameTables(ctx context.Context, tx *sql.Tx) error { + for old, new := range renameTableMappings { + // SQLite has no "IF EXISTS" so check if the table exists. + var name string + if err := tx.QueryRowContext( + ctx, "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = $1;", old, + ).Scan(&name); err != nil { + if err == sql.ErrNoRows { + continue + } + return err + } + q := fmt.Sprintf( + "ALTER TABLE %s RENAME TO %s;", old, new, + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", old, new, err) + } + } + for old, new := range renameIndicesMappings { + var query string + if err := tx.QueryRowContext( + ctx, "SELECT sql FROM sqlite_schema WHERE type = 'index' AND name = $1;", old, + ).Scan(&query); err != nil { + if err == sql.ErrNoRows { + continue + } + return err + } + query = strings.Replace(query, old, new, 1) + if _, err := tx.ExecContext(ctx, fmt.Sprintf("DROP INDEX %s;", old)); err != nil { + return fmt.Errorf("drop index %q to %q error: %w", old, new, err) + } + if _, err := tx.ExecContext(ctx, query); err != nil { + return fmt.Errorf("recreate index %q to %q error: %w", old, new, err) + } + } + return nil +} + +func DownRenameTables(ctx context.Context, tx *sql.Tx) error { + for old, new := range renameTableMappings { + // SQLite has no "IF EXISTS" so check if the table exists. 
+ var name string + if err := tx.QueryRowContext( + ctx, "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = $1;", new, + ).Scan(&name); err != nil { + if err == sql.ErrNoRows { + continue + } + return err + } + q := fmt.Sprintf( + "ALTER TABLE %s RENAME TO %s;", new, old, + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", new, old, err) + } + } + for old, new := range renameIndicesMappings { + var query string + if err := tx.QueryRowContext( + ctx, "SELECT sql FROM sqlite_schema WHERE type = 'index' AND name = $1;", new, + ).Scan(&query); err != nil { + if err == sql.ErrNoRows { + continue + } + return err + } + query = strings.Replace(query, new, old, 1) + if _, err := tx.ExecContext(ctx, fmt.Sprintf("DROP INDEX %s;", new)); err != nil { + return fmt.Errorf("drop index %q to %q error: %w", new, old, err) + } + if _, err := tx.ExecContext(ctx, query); err != nil { + return fmt.Errorf("recreate index %q to %q error: %w", new, old, err) + } + } + return nil +} diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 27a7524d6..e53a08062 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -35,7 +35,7 @@ const devicesSchema = ` -- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; -- Stores data about devices. -CREATE TABLE IF NOT EXISTS device_devices ( +CREATE TABLE IF NOT EXISTS userapi_devices ( access_token TEXT PRIMARY KEY, session_id INTEGER, device_id TEXT , @@ -51,38 +51,38 @@ CREATE TABLE IF NOT EXISTS device_devices ( ` const insertDeviceSQL = "" + - "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + + "INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" const selectDevicesCountSQL = "" + - "SELECT COUNT(access_token) FROM device_devices" + "SELECT COUNT(access_token) FROM userapi_devices" const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + - "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" const deleteDeviceSQL = "" + - "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" + "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" const deleteDevicesSQL = "" + - "DELETE FROM device_devices 
WHERE localpart = $1 AND device_id IN ($2)" + "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" type devicesStatements struct { db *sql.DB diff --git a/userapi/storage/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go index 81726edf9..7883ffb19 100644 --- a/userapi/storage/sqlite3/key_backup_table.go +++ b/userapi/storage/sqlite3/key_backup_table.go @@ -26,7 +26,7 @@ import ( ) const keyBackupTableSchema = ` -CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( +CREATE TABLE IF NOT EXISTS userapi_key_backups ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, session_id TEXT NOT NULL, @@ -37,31 +37,31 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( is_verified BOOLEAN NOT NULL, session_data TEXT NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); -CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); +CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON userapi_key_backups(user_id, room_id, session_id, version); +CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON userapi_key_backups(user_id, version); ` const insertBackupKeySQL = "" + - "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + + "INSERT INTO userapi_key_backups(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" const updateBackupKeySQL = "" + - "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + + "UPDATE userapi_key_backups SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" const countKeysSQL = "" + - "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" const selectKeysSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" const selectKeysByRoomIDSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2 AND room_id = $3" const selectKeysByRoomIDAndSessionIDSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, 
session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" type keyBackupStatements struct { diff --git a/userapi/storage/sqlite3/key_backup_version_table.go b/userapi/storage/sqlite3/key_backup_version_table.go index e85e6f08b..37bc13ed1 100644 --- a/userapi/storage/sqlite3/key_backup_version_table.go +++ b/userapi/storage/sqlite3/key_backup_version_table.go @@ -27,7 +27,7 @@ import ( const keyBackupVersionTableSchema = ` -- the metadata for each generation of encrypted e2e session backups -CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( +CREATE TABLE IF NOT EXISTS userapi_key_backup_versions ( user_id TEXT NOT NULL, -- this means no 2 users will ever have the same version of e2e session backups which strictly -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. @@ -38,26 +38,26 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( deleted INTEGER DEFAULT 0 NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backup_versions_idx ON userapi_key_backup_versions(user_id, version); ` const insertKeyBackupSQL = "" + - "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" + "INSERT INTO userapi_key_backup_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" const updateKeyBackupAuthDataSQL = "" + - "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + "UPDATE userapi_key_backup_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" const updateKeyBackupETagSQL = "" + - "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" + "UPDATE userapi_key_backup_versions SET etag = $1 WHERE user_id = $2 AND version = $3" const deleteKeyBackupSQL = "" + - "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + "UPDATE userapi_key_backup_versions SET deleted=1 WHERE user_id = $1 AND version = $2" const selectKeyBackupSQL = "" + - "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + "SELECT algorithm, auth_data, etag, deleted FROM userapi_key_backup_versions WHERE user_id = $1 AND version = $2" const selectLatestVersionSQL = "" + - "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" + "SELECT MAX(version) FROM userapi_key_backup_versions WHERE user_id = $1" type keyBackupVersionStatements struct { insertKeyBackupStmt *sql.Stmt diff --git a/userapi/storage/sqlite3/logintoken_table.go b/userapi/storage/sqlite3/logintoken_table.go index 78d42029a..2abdcb95e 100644 --- a/userapi/storage/sqlite3/logintoken_table.go +++ b/userapi/storage/sqlite3/logintoken_table.go @@ -32,7 +32,7 @@ type loginTokenStatements struct { } const loginTokenSchema = ` -CREATE TABLE IF NOT EXISTS login_tokens ( +CREATE TABLE IF NOT EXISTS userapi_login_tokens ( -- The random value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, -- When the token expires @@ -43,17 +43,17 @@ CREATE TABLE IF NOT EXISTS login_tokens ( ); -- This index allows efficient garbage collection of expired tokens. 
-CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON userapi_login_tokens(token_expires_at); ` const insertLoginTokenSQL = "" + - "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" + "INSERT INTO userapi_login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" const deleteLoginTokenSQL = "" + - "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" + "DELETE FROM userapi_login_tokens WHERE token = $1 OR token_expires_at <= $2" const selectLoginTokenSQL = "" + - "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" + "SELECT user_id FROM userapi_login_tokens WHERE token = $1 AND token_expires_at > $2" func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) { s := &loginTokenStatements{} @@ -78,7 +78,7 @@ func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx // deleteByToken removes the named token. // // As a simple way to garbage-collect stale tokens, we also remove all expired tokens. -// The login_tokens_expiration_idx index should make that efficient. +// The userapi_login_tokens_expiration_idx index should make that efficient. func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { stmt := sqlutil.TxStmt(txn, s.deleteStmt) res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) diff --git a/userapi/storage/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go index d6090e0da..875f1a9a5 100644 --- a/userapi/storage/sqlite3/openid_table.go +++ b/userapi/storage/sqlite3/openid_table.go @@ -13,7 +13,7 @@ import ( const openIDTokenSchema = ` -- Stores data about accounts. -CREATE TABLE IF NOT EXISTS open_id_tokens ( +CREATE TABLE IF NOT EXISTS userapi_openid_tokens ( -- The value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, -- The Matrix user ID for this account @@ -24,10 +24,10 @@ CREATE TABLE IF NOT EXISTS open_id_tokens ( ` const insertOpenIDTokenSQL = "" + - "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" const selectOpenIDTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" + "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" type openIDTokenStatements struct { db *sql.DB diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index 3050ff4b5..b6130a1e3 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -27,7 +27,7 @@ import ( const profilesSchema = ` -- Stores data about accounts profiles. 
-CREATE TABLE IF NOT EXISTS account_profiles ( +CREATE TABLE IF NOT EXISTS userapi_profiles ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL PRIMARY KEY, -- The display name for this account @@ -38,19 +38,21 @@ CREATE TABLE IF NOT EXISTS account_profiles ( ` const insertProfileSQL = "" + - "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" const setAvatarURLSQL = "" + - "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + + " RETURNING display_name" const setDisplayNameSQL = "" + - "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + + " RETURNING avatar_url" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { db *sql.DB @@ -102,18 +104,40 @@ func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, -) (err error) { +) (*authtypes.Profile, bool, error) { + profile := &authtypes.Profile{ + Localpart: localpart, + AvatarURL: avatarURL, + } + old, err := s.SelectProfileByLocalpart(ctx, localpart) + if err != nil { + return old, false, err + } + if old.AvatarURL == avatarURL { + return old, false, nil + } stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - _, err = stmt.ExecContext(ctx, avatarURL, localpart) - return + err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName) + return profile, true, err } func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, localpart string, displayName string, -) (err error) { +) (*authtypes.Profile, bool, error) { + profile := &authtypes.Profile{ + Localpart: localpart, + DisplayName: displayName, + } + old, err := s.SelectProfileByLocalpart(ctx, localpart) + if err != nil { + return old, false, err + } + if old.DisplayName == displayName { + return old, false, nil + } stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - _, err = stmt.ExecContext(ctx, displayName, localpart) - return + err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL) + return profile, true, err } func (s *profilesStatements) SelectProfilesBySearch( diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go index 8aa1746c5..35e3c653e 100644 --- a/userapi/storage/sqlite3/stats_table.go +++ b/userapi/storage/sqlite3/stats_table.go @@ -46,7 +46,7 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON usera const countUsersLastSeenAfterSQL = "" + "SELECT COUNT(*) FROM (" + - " SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " + + " SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " + " GROUP BY localpart" + " ) u" @@ -63,7 +63,7 @@ R30Users counts the number of 30 day retained users, defined as: const countR30UsersSQL = ` 
SELECT platform, COUNT(*) FROM ( SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts) - FROM account_accounts users + FROM userapi_accounts users INNER JOIN (SELECT localpart, last_seen_ts, @@ -76,7 +76,7 @@ SELECT platform, COUNT(*) FROM ( ELSE 'unknown' END AS platform - FROM device_devices + FROM userapi_devices ) uip ON users.localpart = uip.localpart AND users.account_type <> 4 @@ -126,7 +126,7 @@ GROUP BY client_type ` const countUserByAccountTypeSQL = ` -SELECT COUNT(*) FROM account_accounts WHERE account_type IN ($1) +SELECT COUNT(*) FROM userapi_accounts WHERE account_type IN ($1) ` // $1 = Guest AccountType @@ -139,7 +139,7 @@ SELECT user_type, COUNT(*) AS count FROM ( WHEN account_type = $4 AND appservice_id IS NULL THEN 'guest' WHEN account_type IN ($5) AND appservice_id IS NOT NULL THEN 'bridged' END AS user_type - FROM account_accounts + FROM userapi_accounts WHERE created_ts > $8 ) AS t GROUP BY user_type ` @@ -148,14 +148,14 @@ SELECT user_type, COUNT(*) AS count FROM ( const updateUserDailyVisitsSQL = ` INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent) SELECT u.localpart, u.device_id, $1, MAX(u.user_agent) - FROM device_devices AS u + FROM userapi_devices AS u LEFT JOIN ( SELECT localpart, device_id, timestamp FROM userapi_daily_visits WHERE timestamp = $1 ) udv ON u.localpart = udv.localpart AND u.device_id = udv.device_id - INNER JOIN device_devices d ON d.localpart = u.localpart - INNER JOIN account_accounts a ON a.localpart = u.localpart + INNER JOIN userapi_devices d ON d.localpart = u.localpart + INNER JOIN userapi_accounts a ON a.localpart = u.localpart WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3 AND a.account_type in (1, 3) GROUP BY u.localpart, u.device_id diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 78b7ce588..dd33dc0cf 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/shared" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" ) // NewDatabase creates a new accounts and profiles database @@ -34,6 +35,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return nil, err } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: rename tables", + Up: deltas.UpRenameTables, + Down: deltas.DownRenameTables, + }) + if err = m.Up(base.Context()); err != nil { + return nil, err + } + accountDataTable, err := NewSQLiteAccountDataTable(db) if err != nil { return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) diff --git a/userapi/storage/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go index fa174eed5..73af139db 100644 --- a/userapi/storage/sqlite3/threepid_table.go +++ b/userapi/storage/sqlite3/threepid_table.go @@ -27,7 +27,7 @@ import ( const threepidSchema = ` -- Stores data about third party identifiers -CREATE TABLE IF NOT EXISTS account_threepid ( +CREATE TABLE IF NOT EXISTS userapi_threepids ( -- The third party identifier threepid TEXT NOT NULL, -- The 3PID medium @@ -38,20 +38,20 @@ CREATE TABLE IF NOT EXISTS account_threepid ( PRIMARY KEY(threepid, medium) ); -CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart); +CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart); ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM 
account_threepid WHERE threepid = $1 AND medium = $2" + "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + - "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" + "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" const insertThreePIDSQL = "" + - "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" + "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" const deleteThreePIDSQL = "" + - "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" + "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" type threepidStatements struct { db *sql.DB diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 3d6abd0a6..42e35d1ab 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -16,6 +16,7 @@ import ( "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" @@ -29,14 +30,18 @@ var ( ) func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ + db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") if err != nil { t.Fatalf("NewUserAPIDatabase returned %s", err) } - return db, close + return db, func() { + close() + baseclose() + } } // Tests storing and getting account data @@ -377,15 +382,23 @@ func Test_Profile(t *testing.T) { // set avatar & displayname wantProfile.DisplayName = "Alice" - wantProfile.AvatarURL = "mxc://aliceAvatar" - err = db.SetDisplayName(ctx, aliceLocalpart, "Alice") - assert.NoError(t, err, "unable to set displayname") - err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") - assert.NoError(t, err, "unable to set avatar url") - // verify profile - gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart) - assert.NoError(t, err, "unable to get profile by localpart") + gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice") assert.Equal(t, wantProfile, gotProfile) + assert.NoError(t, err, "unable to set displayname") + assert.True(t, changed) + + wantProfile.AvatarURL = "mxc://aliceAvatar" + gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + assert.NoError(t, err, "unable to set avatar url") + assert.Equal(t, wantProfile, gotProfile) + assert.True(t, changed) + + // Setting the same avatar again doesn't change anything + wantProfile.AvatarURL = "mxc://aliceAvatar" + gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + assert.NoError(t, err, "unable to set avatar url") + assert.Equal(t, wantProfile, gotProfile) + assert.False(t, changed) // search profiles searchRes, err := db.SearchProfiles(ctx, "Alice", 2) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index cc4287997..1b239e442 100644 --- a/userapi/storage/tables/interface.go +++ 
b/userapi/storage/tables/interface.go @@ -84,8 +84,8 @@ type OpenIDTable interface { type ProfileTable interface { InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) - SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error) - SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error) + SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) } diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go index 11521c8b0..c4aec552c 100644 --- a/userapi/storage/tables/stats_table_test.go +++ b/userapi/storage/tables/stats_table_test.go @@ -106,7 +106,7 @@ func mustUpdateDeviceLastSeen( timestamp time.Time, ) { t.Helper() - _, err := db.ExecContext(ctx, "UPDATE device_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) + _, err := db.ExecContext(ctx, "UPDATE userapi_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) if err != nil { t.Fatalf("unable to update device last seen") } @@ -119,7 +119,7 @@ func mustUserUpdateRegistered( localpart string, timestamp time.Time, ) { - _, err := db.ExecContext(ctx, "UPDATE account_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) + _, err := db.ExecContext(ctx, "UPDATE userapi_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) if err != nil { t.Fatalf("unable to update device last seen") } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 984fe8854..aaa93f45b 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -23,13 +23,15 @@ import ( "time" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/userapi" - "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/dendrite/userapi/inthttp" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" @@ -48,9 +50,9 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap if opts.loginTokenLifetime == 0 { opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond } + base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - - accountDB, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ + accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") if err != nil { @@ -64,9 +66,12 @@ func MustMakeInternalAPI(t *testing.T, opts 
apiTestOpts, dbType test.DBType) (ap } return &internal.UserInternalAPI{ - DB: accountDB, - ServerName: cfg.Matrix.ServerName, - }, accountDB, close + DB: accountDB, + ServerName: cfg.Matrix.ServerName, + }, accountDB, func() { + close() + baseclose() + } } func TestQueryProfile(t *testing.T) { @@ -79,10 +84,10 @@ func TestQueryProfile(t *testing.T) { if err != nil { t.Fatalf("failed to make account: %s", err) } - if err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil { + if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil { t.Fatalf("failed to set avatar url: %s", err) } - if err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil { + if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil { t.Fatalf("failed to set display name: %s", err) }
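
The `UpRenameTables`/`DownRenameTables` delta earlier in this diff has to work around the fact that SQLite's `ALTER TABLE ... RENAME` has no `IF EXISTS` form: it first looks the table up in the schema catalogue and only renames it when it is actually present. Below is a small standalone sketch of that technique (not part of this changeset); the database path and helper name are invented for the example, and it queries `sqlite_master`, which every SQLite version accepts, whereas the delta above uses the newer `sqlite_schema` alias.

```go
// Sketch of the "rename only if the old table exists" technique used by the
// userapi rename delta. Not part of this changeset.
package main

import (
	"context"
	"database/sql"
	"fmt"
	"log"

	_ "github.com/mattn/go-sqlite3"
)

// renameTableIfExists renames oldName to newName inside tx, silently skipping
// the rename when the old table is absent (e.g. a fresh database whose schema
// was created with the new names already).
func renameTableIfExists(ctx context.Context, tx *sql.Tx, oldName, newName string) error {
	var name string
	err := tx.QueryRowContext(ctx,
		"SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?;", oldName,
	).Scan(&name)
	if err == sql.ErrNoRows {
		return nil // nothing to rename
	}
	if err != nil {
		return err
	}
	// Identifiers cannot be bound as parameters, so the rename is built with Sprintf,
	// just as the delta does.
	_, err = tx.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", oldName, newName))
	return err
}

func main() {
	db, err := sql.Open("sqlite3", "file:userapi.db") // placeholder path
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	ctx := context.Background()
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		log.Fatal(err)
	}
	if err := renameTableIfExists(ctx, tx, "account_profiles", "userapi_profiles"); err != nil {
		_ = tx.Rollback()
		log.Fatal(err)
	}
	if err := tx.Commit(); err != nil {
		log.Fatal(err)
	}
}
```

Indexes need the extra drop-and-recreate step shown in the delta because SQLite has no statement for renaming an index; reading the original `CREATE INDEX` statement back out of the schema table and re-issuing it under the new name is the usual workaround.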