Merge pull request #44 from globekeeper/release/upstream_0.10.4

Release/upstream 0.10.4
This commit is contained in:
Daniel Aloni 2022-10-27 17:57:41 +03:00 committed by GitHub
commit bc17086f63
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
80 changed files with 2619 additions and 743 deletions

View file

@ -1,5 +1,23 @@
# Changelog # Changelog
## Dendrite 0.10.4 (2022-10-21)
### Features
* Various tables belonging to the user API will be renamed so that they are namespaced with the `userapi_` prefix
* Note that, after upgrading to this version, you should not revert to an older version of Dendrite as the database changes **will not** be reverted automatically
* The backoff and retry behaviour in the federation API has been refactored and improved
### Fixes
* Private read receipt support is now advertised in the client `/versions` endpoint
* Private read receipts will now clear notification counts properly
* A bug where a false `leave` membership transition was inserted into the timeline after accepting an invite has been fixed
* Some panics caused by concurrent map writes in the key server have been fixed
* The sync API now calculates membership transitions from state deltas more accurately
* Transaction IDs are now scoped to endpoints, which should fix some bugs where transaction ID reuse could cause nonsensical cached responses from some endpoints
* The length of the `type`, `sender`, `state_key` and `room_id` fields in events are now verified by number of bytes rather than codepoints after a spec clarification, reverting a change made in Dendrite 0.9.6
## Dendrite 0.10.3 (2022-10-14) ## Dendrite 0.10.3 (2022-10-14)
### Features ### Features

View file

@ -19,6 +19,8 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
@ -27,7 +29,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -120,20 +121,6 @@ func SetAvatarURL(
} }
} }
res := &userapi.QueryProfileResponse{}
err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{
UserID: userID,
}, res)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed")
return jsonerror.InternalServerError()
}
oldProfile := &authtypes.Profile{
Localpart: localpart,
DisplayName: res.DisplayName,
AvatarURL: res.AvatarURL,
}
setRes := &userapi.PerformSetAvatarURLResponse{} setRes := &userapi.PerformSetAvatarURLResponse{}
if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{ if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{
Localpart: localpart, Localpart: localpart,
@ -142,41 +129,17 @@ func SetAvatarURL(
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
// No need to build new membership events, since nothing changed
var roomsRes api.QueryRoomsForUserResponse if !setRes.Changed {
err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
UserID: device.UserID,
WantMembership: "join",
}, &roomsRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed")
return jsonerror.InternalServerError()
}
newProfile := authtypes.Profile{
Localpart: localpart,
DisplayName: oldProfile.DisplayName,
AvatarURL: r.AvatarURL,
}
events, err := buildMembershipEvents(
req.Context(), roomsRes.RoomIDs, newProfile, userID, cfg, evTime, rsAPI,
)
switch e := err.(type) {
case nil:
case gomatrixserverlib.BadJSONError:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusOK,
JSON: jsonerror.BadJSON(e.Error()), JSON: struct{}{},
} }
default:
util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed")
return jsonerror.InternalServerError()
} }
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { response, err := updateProfile(req.Context(), rsAPI, device, setRes.Profile, userID, cfg, evTime)
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") if err != nil {
return jsonerror.InternalServerError() return response
} }
return util.JSONResponse{ return util.JSONResponse{
@ -249,47 +212,51 @@ func SetDisplayName(
} }
} }
pRes := &userapi.QueryProfileResponse{} profileRes := &userapi.PerformUpdateDisplayNameResponse{}
err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{
UserID: userID,
}, pRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed")
return jsonerror.InternalServerError()
}
oldProfile := &authtypes.Profile{
Localpart: localpart,
DisplayName: pRes.DisplayName,
AvatarURL: pRes.AvatarURL,
}
err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{ err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{
Localpart: localpart, Localpart: localpart,
DisplayName: r.DisplayName, DisplayName: r.DisplayName,
}, &struct{}{}) }, profileRes)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
// No need to build new membership events, since nothing changed
if !profileRes.Changed {
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
response, err := updateProfile(req.Context(), rsAPI, device, profileRes.Profile, userID, cfg, evTime)
if err != nil {
return response
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
func updateProfile(
ctx context.Context, rsAPI api.ClientRoomserverAPI, device *userapi.Device,
profile *authtypes.Profile,
userID string, cfg *config.ClientAPI, evTime time.Time,
) (util.JSONResponse, error) {
var res api.QueryRoomsForUserResponse var res api.QueryRoomsForUserResponse
err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: device.UserID, UserID: device.UserID,
WantMembership: "join", WantMembership: "join",
}, &res) }, &res)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError(), err
}
newProfile := authtypes.Profile{
Localpart: localpart,
DisplayName: r.DisplayName,
AvatarURL: oldProfile.AvatarURL,
} }
events, err := buildMembershipEvents( events, err := buildMembershipEvents(
req.Context(), res.RoomIDs, newProfile, userID, cfg, evTime, rsAPI, ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI,
) )
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:
@ -297,21 +264,17 @@ func SetDisplayName(
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(e.Error()), JSON: jsonerror.BadJSON(e.Error()),
} }, e
default: default:
util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed") util.GetLogger(ctx).WithError(err).Error("buildMembershipEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError(), e
} }
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError(), err
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
} }
return util.JSONResponse{}, nil
} }
// getProfile gets the full profile of a user by querying the database or a // getProfile gets the full profile of a user by querying the database or a

View file

@ -19,6 +19,9 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
@ -26,8 +29,6 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
) )
type redactionContent struct { type redactionContent struct {
@ -51,7 +52,7 @@ func SendRedaction(
if txnID != nil { if txnID != nil {
// Try to fetch response from transactionsCache // Try to fetch response from transactionsCache
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
return *res return *res
} }
} }
@ -144,7 +145,7 @@ func SendRedaction(
// Add response to transactionsCache // Add response to transactionsCache
if txnID != nil { if txnID != nil {
txnCache.AddTransaction(device.AccessToken, *txnID, &res) txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res)
} }
return res return res

View file

@ -72,6 +72,7 @@ func Setup(
unstableFeatures := map[string]bool{ unstableFeatures := map[string]bool{
"org.matrix.e2e_cross_signing": true, "org.matrix.e2e_cross_signing": true,
"org.matrix.msc2285.stable": true,
} }
for _, msc := range cfg.MSCs.MSCs { for _, msc := range cfg.MSCs.MSCs {
unstableFeatures["org.matrix."+msc] = true unstableFeatures["org.matrix."+msc] = true
@ -179,7 +180,7 @@ func Setup(
// server notifications // server notifications
if cfg.Matrix.ServerNotices.Enabled { if cfg.Matrix.ServerNotices.Enabled {
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")
serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, cfg) serverNotificationSender, err := getSenderDevice(context.Background(), rsAPI, userAPI, cfg)
if err != nil { if err != nil {
logrus.WithError(err).Fatal("unable to get account for sending sending server notices") logrus.WithError(err).Fatal("unable to get account for sending sending server notices")
} }

View file

@ -86,7 +86,7 @@ func SendEvent(
if txnID != nil { if txnID != nil {
// Try to fetch response from transactionsCache // Try to fetch response from transactionsCache
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
return *res return *res
} }
} }
@ -206,7 +206,7 @@ func SendEvent(
} }
// Add response to transactionsCache // Add response to transactionsCache
if txnID != nil { if txnID != nil {
txnCache.AddTransaction(device.AccessToken, *txnID, &res) txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res)
} }
// Take a note of how long it took to generate the event vs submit // Take a note of how long it took to generate the event vs submit

View file

@ -16,12 +16,13 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/matrix-org/util"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/internal/transactions" "github.com/matrix-org/dendrite/internal/transactions"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
) )
// SendToDevice handles PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId} // SendToDevice handles PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}
@ -33,7 +34,7 @@ func SendToDevice(
eventType string, txnID *string, eventType string, txnID *string,
) util.JSONResponse { ) util.JSONResponse {
if txnID != nil { if txnID != nil {
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
return *res return *res
} }
} }
@ -63,7 +64,7 @@ func SendToDevice(
} }
if txnID != nil { if txnID != nil {
txnCache.AddTransaction(device.AccessToken, *txnID, &res) txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res)
} }
return res return res

View file

@ -21,7 +21,6 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens" "github.com/matrix-org/gomatrixserverlib/tokens"
@ -29,6 +28,8 @@ import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/roomserver/version"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
@ -73,7 +74,7 @@ func SendServerNotice(
if txnID != nil { if txnID != nil {
// Try to fetch response from transactionsCache // Try to fetch response from transactionsCache
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
return *res return *res
} }
} }
@ -251,7 +252,7 @@ func SendServerNotice(
} }
// Add response to transactionsCache // Add response to transactionsCache
if txnID != nil { if txnID != nil {
txnCache.AddTransaction(device.AccessToken, *txnID, &res) txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res)
} }
// Take a note of how long it took to generate the event vs submit // Take a note of how long it took to generate the event vs submit
@ -276,6 +277,7 @@ func (r sendServerNoticeRequest) valid() (ok bool) {
// It returns an userapi.Device, which is used for building the event // It returns an userapi.Device, which is used for building the event
func getSenderDevice( func getSenderDevice(
ctx context.Context, ctx context.Context,
rsAPI api.ClientRoomserverAPI,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) (*userapi.Device, error) { ) (*userapi.Device, error) {
@ -290,16 +292,32 @@ func getSenderDevice(
return nil, err return nil, err
} }
// set the avatarurl for the user // Set the avatarurl for the user
res := &userapi.PerformSetAvatarURLResponse{} avatarRes := &userapi.PerformSetAvatarURLResponse{}
if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{ if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
}, res); err != nil { }, avatarRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed") util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed")
return nil, err return nil, err
} }
profile := avatarRes.Profile
// Set the displayname for the user
displayNameRes := &userapi.PerformUpdateDisplayNameResponse{}
if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart,
DisplayName: cfg.Matrix.ServerNotices.DisplayName,
}, displayNameRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed")
return nil, err
}
if displayNameRes.Changed {
profile.DisplayName = cfg.Matrix.ServerNotices.DisplayName
}
// Check if we got existing devices // Check if we got existing devices
deviceRes := &userapi.QueryDevicesResponse{} deviceRes := &userapi.QueryDevicesResponse{}
err = userAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{ err = userAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{
@ -309,7 +327,15 @@ func getSenderDevice(
return nil, err return nil, err
} }
// We've got an existing account, return the first device of it
if len(deviceRes.Devices) > 0 { if len(deviceRes.Devices) > 0 {
// If there were changes to the profile, create a new membership event
if displayNameRes.Changed || avatarRes.Changed {
_, err = updateProfile(ctx, rsAPI, &deviceRes.Devices[0], profile, accRes.Account.UserID, cfg, time.Now())
if err != nil {
return nil, err
}
}
return &deviceRes.Devices[0], nil return &deviceRes.Devices[0], nil
} }

View file

@ -179,7 +179,10 @@ func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, a
body, _ = io.ReadAll(regResp.Body) body, _ = io.ReadAll(regResp.Body)
return "", fmt.Errorf(gjson.GetBytes(body, "error").Str) return "", fmt.Errorf(gjson.GetBytes(body, "error").Str)
} }
r, _ := io.ReadAll(regResp.Body) r, err := io.ReadAll(regResp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response body (HTTP %d): %w", regResp.StatusCode, err)
}
return gjson.GetBytes(r, "access_token").Str, nil return gjson.GetBytes(r, "access_token").Str, nil
} }

View file

@ -231,9 +231,9 @@ GEM
jekyll-seo-tag (~> 2.1) jekyll-seo-tag (~> 2.1)
minitest (5.15.0) minitest (5.15.0)
multipart-post (2.1.1) multipart-post (2.1.1)
nokogiri (1.13.6-arm64-darwin) nokogiri (1.13.9-arm64-darwin)
racc (~> 1.4) racc (~> 1.4)
nokogiri (1.13.6-x86_64-linux) nokogiri (1.13.9-x86_64-linux)
racc (~> 1.4) racc (~> 1.4)
octokit (4.22.0) octokit (4.22.0)
faraday (>= 0.9) faraday (>= 0.9)

View file

@ -116,17 +116,14 @@ func NewInternalAPI(
_ = federationDB.RemoveAllServersFromBlacklist() _ = federationDB.RemoveAllServersFromBlacklist()
} }
stats := &statistics.Statistics{ stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1)
DB: federationDB,
FailuresUntilBlacklist: cfg.FederationMaxRetries,
}
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
queues := queue.NewOutgoingQueues( queues := queue.NewOutgoingQueues(
federationDB, base.ProcessContext, federationDB, base.ProcessContext,
cfg.Matrix.DisableFederation, cfg.Matrix.DisableFederation,
cfg.Matrix.ServerName, federation, rsAPI, stats, cfg.Matrix.ServerName, federation, rsAPI, &stats,
&queue.SigningInfo{ &queue.SigningInfo{
KeyID: cfg.Matrix.KeyID, KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey, PrivateKey: cfg.Matrix.PrivateKey,
@ -183,5 +180,5 @@ func NewInternalAPI(
} }
time.AfterFunc(time.Minute, cleanExpiredEDUs) time.AfterFunc(time.Minute, cleanExpiredEDUs)
return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues, keyRing) return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, &stats, caches, queues, keyRing)
} }

View file

@ -21,21 +21,22 @@ import (
"sync" "sync"
"time" "time"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"go.uber.org/atomic"
fedapi "github.com/matrix-org/dendrite/federationapi/api" fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"go.uber.org/atomic"
) )
const ( const (
maxPDUsPerTransaction = 50 maxPDUsPerTransaction = 50
maxEDUsPerTransaction = 50 maxEDUsPerTransaction = 100
maxPDUsInMemory = 128 maxPDUsInMemory = 128
maxEDUsInMemory = 128 maxEDUsInMemory = 128
queueIdleTimeout = time.Second * 30 queueIdleTimeout = time.Second * 30
@ -64,7 +65,6 @@ type destinationQueue struct {
pendingPDUs []*queuedPDU // PDUs waiting to be sent pendingPDUs []*queuedPDU // PDUs waiting to be sent
pendingEDUs []*queuedEDU // EDUs waiting to be sent pendingEDUs []*queuedEDU // EDUs waiting to be sent
pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs
interruptBackoff chan bool // interrupts backoff
} }
// Send event adds the event to the pending queue for the destination. // Send event adds the event to the pending queue for the destination.
@ -75,21 +75,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination)
return return
} }
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
if err := oq.db.AssociatePDUWithDestination(
oq.process.Context(),
"", // TODO: remove this, as we don't need to persist the transaction ID
oq.destination, // the destination server name
receipt, // NIDs from federationapi_queue_json table
); err != nil {
logrus.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination)
return
}
// Check if the destination is blacklisted. If it isn't then wake
// up the queue.
if !oq.statistics.Blacklisted() {
// If there's room in memory to hold the event then add it to the // If there's room in memory to hold the event then add it to the
// list. // list.
oq.pendingMutex.Lock() oq.pendingMutex.Lock()
@ -102,12 +88,9 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
oq.overflowed.Store(true) oq.overflowed.Store(true)
} }
oq.pendingMutex.Unlock() oq.pendingMutex.Unlock()
// Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded() if !oq.backingOff.Load() {
select { oq.wakeQueueAndNotify()
case oq.notify <- struct{}{}:
default:
}
} }
} }
@ -119,22 +102,7 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination) logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination)
return return
} }
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
if err := oq.db.AssociateEDUWithDestination(
oq.process.Context(),
oq.destination, // the destination server name
receipt, // NIDs from federationapi_queue_json table
event.Type,
nil, // this will use the default expireEDUTypes map
); err != nil {
logrus.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination)
return
}
// Check if the destination is blacklisted. If it isn't then wake
// up the queue.
if !oq.statistics.Blacklisted() {
// If there's room in memory to hold the event then add it to the // If there's room in memory to hold the event then add it to the
// list. // list.
oq.pendingMutex.Lock() oq.pendingMutex.Lock()
@ -147,24 +115,47 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
oq.overflowed.Store(true) oq.overflowed.Store(true)
} }
oq.pendingMutex.Unlock() oq.pendingMutex.Unlock()
if !oq.backingOff.Load() {
oq.wakeQueueAndNotify()
}
}
// handleBackoffNotifier is registered as the backoff notification
// callback with Statistics. It will wakeup and notify the queue
// if the queue is currently backing off.
func (oq *destinationQueue) handleBackoffNotifier() {
// Only wake up the queue if it is backing off.
// Otherwise there is no pending work for the queue to handle
// so waking the queue would be a waste of resources.
if oq.backingOff.Load() {
oq.wakeQueueAndNotify()
}
}
// wakeQueueAndNotify ensures the destination queue is running and notifies it
// that there is pending work.
func (oq *destinationQueue) wakeQueueAndNotify() {
// Wake up the queue if it's asleep. // Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded() oq.wakeQueueIfNeeded()
// Notify the queue that there are events ready to send.
select { select {
case oq.notify <- struct{}{}: case oq.notify <- struct{}{}:
default: default:
} }
} }
}
// wakeQueueIfNeeded will wake up the destination queue if it is // wakeQueueIfNeeded will wake up the destination queue if it is
// not already running. If it is running but it is backing off // not already running. If it is running but it is backing off
// then we will interrupt the backoff, causing any federation // then we will interrupt the backoff, causing any federation
// requests to retry. // requests to retry.
func (oq *destinationQueue) wakeQueueIfNeeded() { func (oq *destinationQueue) wakeQueueIfNeeded() {
// If we are backing off then interrupt the backoff. // Clear the backingOff flag and update the backoff metrics if it was set.
if oq.backingOff.CompareAndSwap(true, false) { if oq.backingOff.CompareAndSwap(true, false) {
oq.interruptBackoff <- true destinationQueueBackingOff.Dec()
} }
// If we aren't running then wake up the queue. // If we aren't running then wake up the queue.
if !oq.running.Load() { if !oq.running.Load() {
// Start the queue. // Start the queue.
@ -196,38 +187,54 @@ func (oq *destinationQueue) getPendingFromDatabase() {
gotEDUs[edu.receipt.String()] = struct{}{} gotEDUs[edu.receipt.String()] = struct{}{}
} }
overflowed := false
if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 { if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 {
// We have room in memory for some PDUs - let's request no more than that. // We have room in memory for some PDUs - let's request no more than that.
if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil { if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, maxPDUsInMemory); err == nil {
if len(pdus) == maxPDUsInMemory {
overflowed = true
}
for receipt, pdu := range pdus { for receipt, pdu := range pdus {
if _, ok := gotPDUs[receipt.String()]; ok { if _, ok := gotPDUs[receipt.String()]; ok {
continue continue
} }
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu}) oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu})
retrieved = true retrieved = true
if len(oq.pendingPDUs) == maxPDUsInMemory {
break
}
} }
} else { } else {
logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination) logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination)
} }
} }
if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 { if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 {
// We have room in memory for some EDUs - let's request no more than that. // We have room in memory for some EDUs - let's request no more than that.
if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil { if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, maxEDUsInMemory); err == nil {
if len(edus) == maxEDUsInMemory {
overflowed = true
}
for receipt, edu := range edus { for receipt, edu := range edus {
if _, ok := gotEDUs[receipt.String()]; ok { if _, ok := gotEDUs[receipt.String()]; ok {
continue continue
} }
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu}) oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu})
retrieved = true retrieved = true
if len(oq.pendingEDUs) == maxEDUsInMemory {
break
}
} }
} else { } else {
logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination) logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination)
} }
} }
// If we've retrieved all of the events from the database with room to spare // If we've retrieved all of the events from the database with room to spare
// in memory then we'll no longer consider this queue to be overflowed. // in memory then we'll no longer consider this queue to be overflowed.
if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory { if !overflowed {
oq.overflowed.Store(false) oq.overflowed.Store(false)
} else {
} }
// If we've retrieved some events then notify the destination queue goroutine. // If we've retrieved some events then notify the destination queue goroutine.
if retrieved { if retrieved {
@ -238,6 +245,24 @@ func (oq *destinationQueue) getPendingFromDatabase() {
} }
} }
// checkNotificationsOnClose checks for any remaining notifications
// and starts a new backgroundSend goroutine if any exist.
func (oq *destinationQueue) checkNotificationsOnClose() {
// NOTE : If we are stopping the queue due to blacklist then it
// doesn't matter if we have been notified of new work since
// this queue instance will be deleted anyway.
if !oq.statistics.Blacklisted() {
select {
case <-oq.notify:
// We received a new notification in between the
// idle timeout firing and stopping the goroutine.
// Immediately restart the queue.
oq.wakeQueueAndNotify()
default:
}
}
}
// backgroundSend is the worker goroutine for sending events. // backgroundSend is the worker goroutine for sending events.
func (oq *destinationQueue) backgroundSend() { func (oq *destinationQueue) backgroundSend() {
// Check if a worker is already running, and if it isn't, then // Check if a worker is already running, and if it isn't, then
@ -245,10 +270,17 @@ func (oq *destinationQueue) backgroundSend() {
if !oq.running.CompareAndSwap(false, true) { if !oq.running.CompareAndSwap(false, true) {
return return
} }
// Register queue cleanup functions.
// NOTE : The ordering here is very intentional.
defer oq.checkNotificationsOnClose()
defer oq.running.Store(false)
destinationQueueRunning.Inc() destinationQueueRunning.Inc()
defer destinationQueueRunning.Dec() defer destinationQueueRunning.Dec()
defer oq.queues.clearQueue(oq)
defer oq.running.Store(false) idleTimeout := time.NewTimer(queueIdleTimeout)
defer idleTimeout.Stop()
// Mark the queue as overflowed, so we will consult the database // Mark the queue as overflowed, so we will consult the database
// to see if there's anything new to send. // to see if there's anything new to send.
@ -261,59 +293,33 @@ func (oq *destinationQueue) backgroundSend() {
oq.getPendingFromDatabase() oq.getPendingFromDatabase()
} }
// Reset the queue idle timeout.
if !idleTimeout.Stop() {
select {
case <-idleTimeout.C:
default:
}
}
idleTimeout.Reset(queueIdleTimeout)
// If we have nothing to do then wait either for incoming events, or // If we have nothing to do then wait either for incoming events, or
// until we hit an idle timeout. // until we hit an idle timeout.
select { select {
case <-oq.notify: case <-oq.notify:
// There's work to do, either because getPendingFromDatabase // There's work to do, either because getPendingFromDatabase
// told us there is, or because a new event has come in via // told us there is, a new event has come in via sendEvent/sendEDU,
// sendEvent/sendEDU. // or we are backing off and it is time to retry.
case <-time.After(queueIdleTimeout): case <-idleTimeout.C:
// The worker is idle so stop the goroutine. It'll get // The worker is idle so stop the goroutine. It'll get
// restarted automatically the next time we have an event to // restarted automatically the next time we have an event to
// send. // send.
return return
case <-oq.process.Context().Done(): case <-oq.process.Context().Done():
// The parent process is shutting down, so stop. // The parent process is shutting down, so stop.
oq.statistics.ClearBackoff()
return return
} }
// If we are backing off this server then wait for the
// backoff duration to complete first, or until explicitly
// told to retry.
until, blacklisted := oq.statistics.BackoffInfo()
if blacklisted {
// It's been suggested that we should give up because the backoff
// has exceeded a maximum allowable value. Clean up the in-memory
// buffers at this point. The PDU clean-up is already on a defer.
logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = nil
oq.pendingEDUs = nil
oq.pendingMutex.Unlock()
return
}
if until != nil && until.After(time.Now()) {
// We haven't backed off yet, so wait for the suggested amount of
// time.
duration := time.Until(*until)
logrus.Debugf("Backing off %q for %s", oq.destination, duration)
oq.backingOff.Store(true)
destinationQueueBackingOff.Inc()
select {
case <-time.After(duration):
case <-oq.interruptBackoff:
}
destinationQueueBackingOff.Dec()
oq.backingOff.Store(false)
}
// Work out which PDUs/EDUs to include in the next transaction. // Work out which PDUs/EDUs to include in the next transaction.
oq.pendingMutex.RLock() oq.pendingMutex.RLock()
pduCount := len(oq.pendingPDUs) pduCount := len(oq.pendingPDUs)
@ -328,99 +334,52 @@ func (oq *destinationQueue) backgroundSend() {
toSendEDUs := oq.pendingEDUs[:eduCount] toSendEDUs := oq.pendingEDUs[:eduCount]
oq.pendingMutex.RUnlock() oq.pendingMutex.RUnlock()
// If we didn't get anything from the database and there are no
// pending EDUs then there's nothing to do - stop here.
if pduCount == 0 && eduCount == 0 {
continue
}
// If we have pending PDUs or EDUs then construct a transaction. // If we have pending PDUs or EDUs then construct a transaction.
// Try sending the next transaction and see what happens. // Try sending the next transaction and see what happens.
transaction, pc, ec, terr := oq.nextTransaction(toSendPDUs, toSendEDUs) terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
if terr != nil { if terr != nil {
// We failed to send the transaction. Mark it as a failure. // We failed to send the transaction. Mark it as a failure.
oq.statistics.Failure() _, blacklisted := oq.statistics.Failure()
if !blacklisted {
} else if transaction { // Register the backoff state and exit the goroutine.
// If we successfully sent the transaction then clear out // It'll get restarted automatically when the backoff
// the pending events and EDUs, and wipe our transaction ID. // completes.
oq.statistics.Success() oq.backingOff.Store(true)
oq.pendingMutex.Lock() destinationQueueBackingOff.Inc()
for i := range oq.pendingPDUs[:pc] { return
oq.pendingPDUs[i] = nil } else {
// Immediately trigger the blacklist logic.
oq.blacklistDestination()
return
} }
for i := range oq.pendingEDUs[:ec] { } else {
oq.pendingEDUs[i] = nil oq.handleTransactionSuccess(pduCount, eduCount)
}
oq.pendingPDUs = oq.pendingPDUs[pc:]
oq.pendingEDUs = oq.pendingEDUs[ec:]
oq.pendingMutex.Unlock()
} }
} }
} }
// nextTransaction creates a new transaction from the pending event // nextTransaction creates a new transaction from the pending event
// queue and sends it. Returns true if a transaction was sent or // queue and sends it.
// false otherwise. // Returns an error if the transaction wasn't sent.
func (oq *destinationQueue) nextTransaction( func (oq *destinationQueue) nextTransaction(
pdus []*queuedPDU, pdus []*queuedPDU,
edus []*queuedEDU, edus []*queuedEDU,
) (bool, int, int, error) { ) error {
// If there's no projected transaction ID then generate one. If
// the transaction succeeds then we'll set it back to "" so that
// we generate a new one next time. If it fails, we'll preserve
// it so that we retry with the same transaction ID.
oq.transactionIDMutex.Lock()
if oq.transactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
oq.transactionIDMutex.Unlock()
// Create the transaction. // Create the transaction.
t := gomatrixserverlib.Transaction{ t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus)
PDUs: []json.RawMessage{},
EDUs: []gomatrixserverlib.EDU{},
}
t.Origin = oq.origin
t.Destination = oq.destination
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = oq.transactionID
// If we didn't get anything from the database and there are no
// pending EDUs then there's nothing to do - stop here.
if len(pdus) == 0 && len(edus) == 0 {
return false, 0, 0, nil
}
var pduReceipts []*shared.Receipt
var eduReceipts []*shared.Receipt
// Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction.
for _, pdu := range pdus {
if pdu == nil || pdu.pdu == nil {
continue
}
// Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, pdu.pdu.JSON())
pduReceipts = append(pduReceipts, pdu.receipt)
}
// Do the same for pending EDUS in the queue.
for _, edu := range edus {
if edu == nil || edu.edu == nil {
continue
}
t.EDUs = append(t.EDUs, *edu.edu)
eduReceipts = append(eduReceipts, edu.receipt)
}
logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
// Try to send the transaction to the destination server. // Try to send the transaction to the destination server.
// TODO: we should check for 500-ish fails vs 400-ish here,
// since we shouldn't queue things indefinitely in response
// to a 400-ish error
ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5) ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5)
defer cancel() defer cancel()
_, err := oq.client.SendTransaction(ctx, t) _, err := oq.client.SendTransaction(ctx, t)
switch err.(type) { switch errResponse := err.(type) {
case nil: case nil:
// Clean up the transaction in the database. // Clean up the transaction in the database.
if pduReceipts != nil { if pduReceipts != nil {
@ -439,16 +398,129 @@ func (oq *destinationQueue) nextTransaction(
oq.transactionIDMutex.Lock() oq.transactionIDMutex.Lock()
oq.transactionID = "" oq.transactionID = ""
oq.transactionIDMutex.Unlock() oq.transactionIDMutex.Unlock()
return true, len(t.PDUs), len(t.EDUs), nil return nil
case gomatrix.HTTPError: case gomatrix.HTTPError:
// Report that we failed to send the transaction and we // Report that we failed to send the transaction and we
// will retry again, subject to backoff. // will retry again, subject to backoff.
return false, 0, 0, err
// TODO: we should check for 500-ish fails vs 400-ish here,
// since we shouldn't queue things indefinitely in response
// to a 400-ish error
code := errResponse.Code
logrus.Debug("Transaction failed with HTTP", code)
return err
default: default:
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"destination": oq.destination, "destination": oq.destination,
logrus.ErrorKey: err, logrus.ErrorKey: err,
}).Debugf("Failed to send transaction %q", t.TransactionID) }).Debugf("Failed to send transaction %q", t.TransactionID)
return false, 0, 0, err return err
}
}
// createTransaction generates a gomatrixserverlib.Transaction from the provided pdus and edus.
// It also returns the associated event receipts so they can be cleaned from the database in
// the case of a successful transaction.
func (oq *destinationQueue) createTransaction(
pdus []*queuedPDU,
edus []*queuedEDU,
) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) {
// If there's no projected transaction ID then generate one. If
// the transaction succeeds then we'll set it back to "" so that
// we generate a new one next time. If it fails, we'll preserve
// it so that we retry with the same transaction ID.
oq.transactionIDMutex.Lock()
if oq.transactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
oq.transactionIDMutex.Unlock()
t := gomatrixserverlib.Transaction{
PDUs: []json.RawMessage{},
EDUs: []gomatrixserverlib.EDU{},
}
t.Origin = oq.origin
t.Destination = oq.destination
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = oq.transactionID
var pduReceipts []*shared.Receipt
var eduReceipts []*shared.Receipt
// Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction.
for _, pdu := range pdus {
// These should never be nil.
if pdu == nil || pdu.pdu == nil {
continue
}
// Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, pdu.pdu.JSON())
pduReceipts = append(pduReceipts, pdu.receipt)
}
// Do the same for pending EDUS in the queue.
for _, edu := range edus {
// These should never be nil.
if edu == nil || edu.edu == nil {
continue
}
t.EDUs = append(t.EDUs, *edu.edu)
eduReceipts = append(eduReceipts, edu.receipt)
}
return t, pduReceipts, eduReceipts
}
// blacklistDestination removes all pending PDUs and EDUs that have been cached
// and deletes this queue.
func (oq *destinationQueue) blacklistDestination() {
// It's been suggested that we should give up because the backoff
// has exceeded a maximum allowable value. Clean up the in-memory
// buffers at this point. The PDU clean-up is already on a defer.
logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = nil
oq.pendingEDUs = nil
oq.pendingMutex.Unlock()
// Delete this queue as no more messages will be sent to this
// destination until it is no longer blacklisted.
oq.statistics.AssignBackoffNotifier(nil)
oq.queues.clearQueue(oq)
}
// handleTransactionSuccess updates the cached event queues as well as the success and
// backoff information for this server.
func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) {
// If we successfully sent the transaction then clear out
// the pending events and EDUs, and wipe our transaction ID.
oq.statistics.Success()
oq.pendingMutex.Lock()
defer oq.pendingMutex.Unlock()
for i := range oq.pendingPDUs[:pduCount] {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs[:eduCount] {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = oq.pendingPDUs[pduCount:]
oq.pendingEDUs = oq.pendingEDUs[eduCount:]
if len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0 {
select {
case oq.notify <- struct{}{}:
default:
}
} }
} }

View file

@ -24,6 +24,7 @@ import (
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@ -171,14 +172,16 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d
client: oqs.client, client: oqs.client,
statistics: oqs.statistics.ForServer(destination), statistics: oqs.statistics.ForServer(destination),
notify: make(chan struct{}, 1), notify: make(chan struct{}, 1),
interruptBackoff: make(chan bool),
signing: oqs.signing, signing: oqs.signing,
} }
oq.statistics.AssignBackoffNotifier(oq.handleBackoffNotifier)
oqs.queues[destination] = oq oqs.queues[destination] = oq
} }
return oq return oq
} }
// clearQueue removes the queue for the provided destination from the
// set of destination queues.
func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) { func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) {
oqs.queuesMutex.Lock() oqs.queuesMutex.Lock()
defer oqs.queuesMutex.Unlock() defer oqs.queuesMutex.Unlock()
@ -245,11 +248,25 @@ func (oqs *OutgoingQueues) SendEvent(
} }
for destination := range destmap { for destination := range destmap {
if queue := oqs.getQueue(destination); queue != nil { if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() {
queue.sendEvent(ev, nid) queue.sendEvent(ev, nid)
} else {
delete(destmap, destination)
} }
} }
// Create a database entry that associates the given PDU NID with
// this destinations queue. We'll then be able to retrieve the PDU
// later.
if err := oqs.db.AssociatePDUWithDestinations(
oqs.process.Context(),
destmap,
nid, // NIDs from federationapi_queue_json table
); err != nil {
logrus.WithError(err).Errorf("failed to associate PDUs %q with destinations", nid)
return err
}
return nil return nil
} }
@ -319,11 +336,27 @@ func (oqs *OutgoingQueues) SendEDU(
} }
for destination := range destmap { for destination := range destmap {
if queue := oqs.getQueue(destination); queue != nil { if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() {
queue.sendEDU(e, nid) queue.sendEDU(e, nid)
} else {
delete(destmap, destination)
} }
} }
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
if err := oqs.db.AssociateEDUWithDestinations(
oqs.process.Context(),
destmap, // the destination server name
nid, // NIDs from federationapi_queue_json table
e.Type,
nil, // this will use the default expireEDUTypes map
); err != nil {
logrus.WithError(err).Errorf("failed to associate EDU with destinations")
return err
}
return nil return nil
} }
@ -332,7 +365,9 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
if oqs.disabled { if oqs.disabled {
return return
} }
oqs.statistics.ForServer(srv).RemoveBlacklist()
if queue := oqs.getQueue(srv); queue != nil { if queue := oqs.getQueue(srv); queue != nil {
queue.statistics.ClearBackoff()
queue.wakeQueueIfNeeded() queue.wakeQueueIfNeeded()
} }
} }

File diff suppressed because it is too large Load diff

View file

@ -2,6 +2,7 @@ package statistics
import ( import (
"math" "math"
"math/rand"
"sync" "sync"
"time" "time"
@ -20,12 +21,23 @@ type Statistics struct {
servers map[gomatrixserverlib.ServerName]*ServerStatistics servers map[gomatrixserverlib.ServerName]*ServerStatistics
mutex sync.RWMutex mutex sync.RWMutex
backoffTimers map[gomatrixserverlib.ServerName]*time.Timer
backoffMutex sync.RWMutex
// How many times should we tolerate consecutive failures before we // How many times should we tolerate consecutive failures before we
// just blacklist the host altogether? The backoff is exponential, // just blacklist the host altogether? The backoff is exponential,
// so the max time here to attempt is 2**failures seconds. // so the max time here to attempt is 2**failures seconds.
FailuresUntilBlacklist uint32 FailuresUntilBlacklist uint32
} }
func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics {
return Statistics{
DB: db,
FailuresUntilBlacklist: failuresUntilBlacklist,
backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer),
}
}
// ForServer returns server statistics for the given server name. If it // ForServer returns server statistics for the given server name. If it
// does not exist, it will create empty statistics and return those. // does not exist, it will create empty statistics and return those.
func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics { func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics {
@ -45,7 +57,6 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
server = &ServerStatistics{ server = &ServerStatistics{
statistics: s, statistics: s,
serverName: serverName, serverName: serverName,
interrupt: make(chan struct{}),
} }
s.servers[serverName] = server s.servers[serverName] = server
s.mutex.Unlock() s.mutex.Unlock()
@ -70,23 +81,37 @@ type ServerStatistics struct {
backoffStarted atomic.Bool // is the backoff started backoffStarted atomic.Bool // is the backoff started
backoffUntil atomic.Value // time.Time until this backoff interval ends backoffUntil atomic.Value // time.Time until this backoff interval ends
backoffCount atomic.Uint32 // number of times BackoffDuration has been called backoffCount atomic.Uint32 // number of times BackoffDuration has been called
interrupt chan struct{} // interrupts the backoff goroutine
successCounter atomic.Uint32 // how many times have we succeeded? successCounter atomic.Uint32 // how many times have we succeeded?
backoffNotifier func() // notifies destination queue when backoff completes
notifierMutex sync.Mutex
} }
const maxJitterMultiplier = 1.4
const minJitterMultiplier = 0.8
// duration returns how long the next backoff interval should be. // duration returns how long the next backoff interval should be.
func (s *ServerStatistics) duration(count uint32) time.Duration { func (s *ServerStatistics) duration(count uint32) time.Duration {
return time.Second * time.Duration(math.Exp2(float64(count))) // Add some jitter to minimise the chance of having multiple backoffs
// ending at the same time.
jitter := rand.Float64()*(maxJitterMultiplier-minJitterMultiplier) + minJitterMultiplier
duration := time.Millisecond * time.Duration(math.Exp2(float64(count))*jitter*1000)
return duration
} }
// cancel will interrupt the currently active backoff. // cancel will interrupt the currently active backoff.
func (s *ServerStatistics) cancel() { func (s *ServerStatistics) cancel() {
s.blacklisted.Store(false) s.blacklisted.Store(false)
s.backoffUntil.Store(time.Time{}) s.backoffUntil.Store(time.Time{})
select {
case s.interrupt <- struct{}{}: s.ClearBackoff()
default:
} }
// AssignBackoffNotifier configures the channel to send to when
// a backoff completes.
func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) {
s.notifierMutex.Lock()
defer s.notifierMutex.Unlock()
s.backoffNotifier = notifier
} }
// Success updates the server statistics with a new successful // Success updates the server statistics with a new successful
@ -95,8 +120,8 @@ func (s *ServerStatistics) cancel() {
// we will unblacklist it. // we will unblacklist it.
func (s *ServerStatistics) Success() { func (s *ServerStatistics) Success() {
s.cancel() s.cancel()
s.successCounter.Inc()
s.backoffCount.Store(0) s.backoffCount.Store(0)
s.successCounter.Inc()
if s.statistics.DB != nil { if s.statistics.DB != nil {
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
@ -105,13 +130,17 @@ func (s *ServerStatistics) Success() {
} }
// Failure marks a failure and starts backing off if needed. // Failure marks a failure and starts backing off if needed.
// The next call to BackoffIfRequired will do the right thing // It will return the time that the current failure
// after this. It will return the time that the current failure
// will result in backoff waiting until, and a bool signalling // will result in backoff waiting until, and a bool signalling
// whether we have blacklisted and therefore to give up. // whether we have blacklisted and therefore to give up.
func (s *ServerStatistics) Failure() (time.Time, bool) { func (s *ServerStatistics) Failure() (time.Time, bool) {
// Return immediately if we have blacklisted this node.
if s.blacklisted.Load() {
return time.Time{}, true
}
// If we aren't already backing off, this call will start // If we aren't already backing off, this call will start
// a new backoff period. Increase the failure counter and // a new backoff period, increase the failure counter and
// start a goroutine which will wait out the backoff and // start a goroutine which will wait out the backoff and
// unset the backoffStarted flag when done. // unset the backoffStarted flag when done.
if s.backoffStarted.CompareAndSwap(false, true) { if s.backoffStarted.CompareAndSwap(false, true) {
@ -122,40 +151,48 @@ func (s *ServerStatistics) Failure() (time.Time, bool) {
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
} }
} }
s.ClearBackoff()
return time.Time{}, true return time.Time{}, true
} }
go func() { // We're starting a new back off so work out what the next interval
until, ok := s.backoffUntil.Load().(time.Time)
if ok && !until.IsZero() {
select {
case <-time.After(time.Until(until)):
case <-s.interrupt:
}
s.backoffStarted.Store(false)
}
}()
}
// Check if we have blacklisted this node.
if s.blacklisted.Load() {
return time.Now(), true
}
// If we're already backing off and we haven't yet surpassed
// the deadline then return that. Repeated calls to Failure
// within a single backoff interval will have no side effects.
if until, ok := s.backoffUntil.Load().(time.Time); ok && !time.Now().After(until) {
return until, false
}
// We're either backing off and have passed the deadline, or
// we aren't backing off, so work out what the next interval
// will be. // will be.
count := s.backoffCount.Load() count := s.backoffCount.Load()
until := time.Now().Add(s.duration(count)) until := time.Now().Add(s.duration(count))
s.backoffUntil.Store(until) s.backoffUntil.Store(until)
return until, false
s.statistics.backoffMutex.Lock()
defer s.statistics.backoffMutex.Unlock()
s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished)
}
return s.backoffUntil.Load().(time.Time), false
}
// ClearBackoff stops the backoff timer for this destination if it is running
// and removes the timer from the backoffTimers map.
func (s *ServerStatistics) ClearBackoff() {
// If the timer is still running then stop it so it's memory is cleaned up sooner.
s.statistics.backoffMutex.Lock()
defer s.statistics.backoffMutex.Unlock()
if timer, ok := s.statistics.backoffTimers[s.serverName]; ok {
timer.Stop()
}
delete(s.statistics.backoffTimers, s.serverName)
s.backoffStarted.Store(false)
}
// backoffFinished will clear the previous backoff and notify the destination queue.
func (s *ServerStatistics) backoffFinished() {
s.ClearBackoff()
// Notify the destinationQueue if one is currently running.
s.notifierMutex.Lock()
defer s.notifierMutex.Unlock()
if s.backoffNotifier != nil {
s.backoffNotifier()
}
} }
// BackoffInfo returns information about the current or previous backoff. // BackoffInfo returns information about the current or previous backoff.
@ -174,6 +211,12 @@ func (s *ServerStatistics) Blacklisted() bool {
return s.blacklisted.Load() return s.blacklisted.Load()
} }
// RemoveBlacklist removes the blacklisted status from the server.
func (s *ServerStatistics) RemoveBlacklist() {
s.cancel()
s.backoffCount.Store(0)
}
// SuccessCount returns the number of successful requests. This is // SuccessCount returns the number of successful requests. This is
// usually useful in constructing transaction IDs. // usually useful in constructing transaction IDs.
func (s *ServerStatistics) SuccessCount() uint32 { func (s *ServerStatistics) SuccessCount() uint32 {

View file

@ -7,9 +7,7 @@ import (
) )
func TestBackoff(t *testing.T) { func TestBackoff(t *testing.T) {
stats := Statistics{ stats := NewStatistics(nil, 7)
FailuresUntilBlacklist: 7,
}
server := ServerStatistics{ server := ServerStatistics{
statistics: &stats, statistics: &stats,
serverName: "test.com", serverName: "test.com",
@ -36,7 +34,7 @@ func TestBackoff(t *testing.T) {
// Get the duration. // Get the duration.
_, blacklist := server.BackoffInfo() _, blacklist := server.BackoffInfo()
duration := time.Until(until).Round(time.Second) duration := time.Until(until)
// Unset the backoff, or otherwise our next call will think that // Unset the backoff, or otherwise our next call will think that
// there's a backoff in progress and return the same result. // there's a backoff in progress and return the same result.
@ -57,8 +55,17 @@ func TestBackoff(t *testing.T) {
// Check if the duration is what we expect. // Check if the duration is what we expect.
t.Logf("Backoff %d is for %s", i, duration) t.Logf("Backoff %d is for %s", i, duration)
if wanted := time.Second * time.Duration(math.Exp2(float64(i))); !blacklist && duration != wanted { roundingAllowance := 0.01
t.Fatalf("Backoff %d should have been %s but was %s", i, wanted, duration) minDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*minJitterMultiplier*1000-roundingAllowance)
maxDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*maxJitterMultiplier*1000+roundingAllowance)
var inJitterRange bool
if duration >= minDuration && duration <= maxDuration {
inJitterRange = true
} else {
inJitterRange = false
}
if !blacklist && !inJitterRange {
t.Fatalf("Backoff %d should have been between %s and %s but was %s", i, minDuration, maxDuration, duration)
} }
} }
} }

View file

@ -18,9 +18,10 @@ import (
"context" "context"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/gomatrixserverlib"
) )
type Database interface { type Database interface {
@ -38,8 +39,8 @@ type Database interface {
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error)
AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error
AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error

View file

@ -52,6 +52,10 @@ type Receipt struct {
nid int64 nid int64
} }
func NewReceipt(nid int64) Receipt {
return Receipt{nid: nid}
}
func (r *Receipt) String() string { func (r *Receipt) String() string {
return fmt.Sprintf("%d", r.nid) return fmt.Sprintf("%d", r.nid)
} }

View file

@ -38,9 +38,9 @@ var defaultExpireEDUTypes = map[string]time.Duration{
// AssociateEDUWithDestination creates an association that the // AssociateEDUWithDestination creates an association that the
// destination queues will use to determine which JSON blobs to send // destination queues will use to determine which JSON blobs to send
// to which servers. // to which servers.
func (d *Database) AssociateEDUWithDestination( func (d *Database) AssociateEDUWithDestinations(
ctx context.Context, ctx context.Context,
serverName gomatrixserverlib.ServerName, destinations map[gomatrixserverlib.ServerName]struct{},
receipt *Receipt, receipt *Receipt,
eduType string, eduType string,
expireEDUTypes map[string]time.Duration, expireEDUTypes map[string]time.Duration,
@ -59,17 +59,18 @@ func (d *Database) AssociateEDUWithDestination(
expiresAt = 0 expiresAt = 0
} }
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationQueueEDUs.InsertQueueEDU( var err error
for destination := range destinations {
err = d.FederationQueueEDUs.InsertQueueEDU(
ctx, // context ctx, // context
txn, // SQL transaction txn, // SQL transaction
eduType, // EDU type for coalescing eduType, // EDU type for coalescing
serverName, // destination server name destination, // destination server name
receipt.nid, // NID from the federationapi_queue_json table receipt.nid, // NID from the federationapi_queue_json table
expiresAt, // The timestamp this EDU will expire expiresAt, // The timestamp this EDU will expire
); err != nil { )
return fmt.Errorf("InsertQueueEDU: %w", err)
} }
return nil return err
}) })
} }

View file

@ -27,23 +27,23 @@ import (
// AssociatePDUWithDestination creates an association that the // AssociatePDUWithDestination creates an association that the
// destination queues will use to determine which JSON blobs to send // destination queues will use to determine which JSON blobs to send
// to which servers. // to which servers.
func (d *Database) AssociatePDUWithDestination( func (d *Database) AssociatePDUWithDestinations(
ctx context.Context, ctx context.Context,
transactionID gomatrixserverlib.TransactionID, destinations map[gomatrixserverlib.ServerName]struct{},
serverName gomatrixserverlib.ServerName,
receipt *Receipt, receipt *Receipt,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationQueuePDUs.InsertQueuePDU( var err error
for destination := range destinations {
err = d.FederationQueuePDUs.InsertQueuePDU(
ctx, // context ctx, // context
txn, // SQL transaction txn, // SQL transaction
transactionID, // transaction ID "", // transaction ID
serverName, // destination server name destination, // destination server name
receipt.nid, // NID from the federationapi_queue_json table receipt.nid, // NID from the federationapi_queue_json table
); err != nil { )
return fmt.Errorf("InsertQueuePDU: %w", err)
} }
return nil return err
}) })
} }

View file

@ -35,6 +35,7 @@ func TestExpireEDUs(t *testing.T) {
} }
ctx := context.Background() ctx := context.Background()
destinations := map[gomatrixserverlib.ServerName]struct{}{"localhost": {}}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateFederationDatabase(t, dbType) db, close := mustCreateFederationDatabase(t, dbType)
defer close() defer close()
@ -43,7 +44,7 @@ func TestExpireEDUs(t *testing.T) {
receipt, err := db.StoreJSON(ctx, "{}") receipt, err := db.StoreJSON(ctx, "{}")
assert.NoError(t, err) assert.NoError(t, err)
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MReceipt, expireEDUTypes) err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MReceipt, expireEDUTypes)
assert.NoError(t, err) assert.NoError(t, err)
} }
// add data without expiry // add data without expiry
@ -51,7 +52,7 @@ func TestExpireEDUs(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test // m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, "m.read_marker", expireEDUTypes) err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, "m.read_marker", expireEDUTypes)
assert.NoError(t, err) assert.NoError(t, err)
// Delete expired EDUs // Delete expired EDUs
@ -67,7 +68,7 @@ func TestExpireEDUs(t *testing.T) {
receipt, err = db.StoreJSON(ctx, "{}") receipt, err = db.StoreJSON(ctx, "{}")
assert.NoError(t, err) assert.NoError(t, err)
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes) err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes)
assert.NoError(t, err) assert.NoError(t, err)
err = db.DeleteExpiredEDUs(ctx) err = db.DeleteExpiredEDUs(ctx)

4
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241 github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a
github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.15 github.com/mattn/go-sqlite3 v1.14.15
@ -50,6 +50,7 @@ require (
golang.org/x/term v0.0.0-20220919170432-7a66f970e087 golang.org/x/term v0.0.0-20220919170432-7a66f970e087
gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
gotest.tools/v3 v3.4.0
nhooyr.io/websocket v1.8.7 nhooyr.io/websocket v1.8.7
) )
@ -129,7 +130,6 @@ require (
gopkg.in/macaroon.v2 v2.1.0 // indirect gopkg.in/macaroon.v2 v2.1.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
gotest.tools/v3 v3.4.0 // indirect
) )
go 1.18 go 1.18

4
go.sum
View file

@ -385,8 +385,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241 h1:e5o68MWeU7wjTvvNKmVo655oCYesoNRoPeBb1Xfz54g= github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a h1:6rJFN5NBuzZ7h5meYkLtXKa6VFZfDc8oVXHd4SDXr5o=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 h1:lzkSQvBv8TuqKJCPoVwOVvEnARTlua5rrNy/Qw2Vxeo= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 h1:lzkSQvBv8TuqKJCPoVwOVvEnARTlua5rrNy/Qw2Vxeo=
github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=

View file

@ -13,6 +13,8 @@
package transactions package transactions
import ( import (
"net/url"
"path/filepath"
"sync" "sync"
"time" "time"
@ -29,6 +31,7 @@ type txnsMap map[CacheKey]*util.JSONResponse
type CacheKey struct { type CacheKey struct {
AccessToken string AccessToken string
TxnID string TxnID string
Endpoint string
} }
// Cache represents a temporary store for response entries. // Cache represents a temporary store for response entries.
@ -57,14 +60,14 @@ func NewWithCleanupPeriod(cleanupPeriod time.Duration) *Cache {
return &t return &t
} }
// FetchTransaction looks up an entry for the (accessToken, txnID) tuple in Cache. // FetchTransaction looks up an entry for the (accessToken, txnID, req.URL) tuple in Cache.
// Looks in both the txnMaps. // Looks in both the txnMaps.
// Returns (JSON response, true) if txnID is found, else the returned bool is false. // Returns (JSON response, true) if txnID is found, else the returned bool is false.
func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse, bool) { func (t *Cache) FetchTransaction(accessToken, txnID string, u *url.URL) (*util.JSONResponse, bool) {
t.RLock() t.RLock()
defer t.RUnlock() defer t.RUnlock()
for _, txns := range t.txnsMaps { for _, txns := range t.txnsMaps {
res, ok := txns[CacheKey{accessToken, txnID}] res, ok := txns[CacheKey{accessToken, txnID, filepath.Dir(u.Path)}]
if ok { if ok {
return res, true return res, true
} }
@ -72,13 +75,12 @@ func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse,
return nil, false return nil, false
} }
// AddTransaction adds an entry for the (accessToken, txnID) tuple in Cache. // AddTransaction adds an entry for the (accessToken, txnID, req.URL) tuple in Cache.
// Adds to the front txnMap. // Adds to the front txnMap.
func (t *Cache) AddTransaction(accessToken, txnID string, res *util.JSONResponse) { func (t *Cache) AddTransaction(accessToken, txnID string, u *url.URL, res *util.JSONResponse) {
t.Lock() t.Lock()
defer t.Unlock() defer t.Unlock()
t.txnsMaps[0][CacheKey{accessToken, txnID, filepath.Dir(u.Path)}] = res
t.txnsMaps[0][CacheKey{accessToken, txnID}] = res
} }
// cacheCleanService is responsible for cleaning up entries after cleanupPeriod. // cacheCleanService is responsible for cleaning up entries after cleanupPeriod.

View file

@ -14,6 +14,9 @@ package transactions
import ( import (
"net/http" "net/http"
"net/url"
"path/filepath"
"reflect"
"strconv" "strconv"
"testing" "testing"
@ -24,6 +27,16 @@ type fakeType struct {
ID string `json:"ID"` ID string `json:"ID"`
} }
func TestCompare(t *testing.T) {
u1, _ := url.Parse("/send/1?accessToken=123")
u2, _ := url.Parse("/send/1")
c1 := CacheKey{"1", "2", filepath.Dir(u1.Path)}
c2 := CacheKey{"1", "2", filepath.Dir(u2.Path)}
if !reflect.DeepEqual(c1, c2) {
t.Fatalf("Cache keys differ: %+v <> %+v", c1, c2)
}
}
var ( var (
fakeAccessToken = "aRandomAccessToken" fakeAccessToken = "aRandomAccessToken"
fakeAccessToken2 = "anotherRandomAccessToken" fakeAccessToken2 = "anotherRandomAccessToken"
@ -34,23 +47,28 @@ var (
fakeResponse2 = &util.JSONResponse{ fakeResponse2 = &util.JSONResponse{
Code: http.StatusOK, JSON: fakeType{ID: "1"}, Code: http.StatusOK, JSON: fakeType{ID: "1"},
} }
fakeResponse3 = &util.JSONResponse{
Code: http.StatusOK, JSON: fakeType{ID: "2"},
}
) )
// TestCache creates a New Cache and tests AddTransaction & FetchTransaction // TestCache creates a New Cache and tests AddTransaction & FetchTransaction
func TestCache(t *testing.T) { func TestCache(t *testing.T) {
fakeTxnCache := New() fakeTxnCache := New()
fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse) u, _ := url.Parse("")
fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, u, fakeResponse)
// Add entries for noise. // Add entries for noise.
for i := 1; i <= 100; i++ { for i := 1; i <= 100; i++ {
fakeTxnCache.AddTransaction( fakeTxnCache.AddTransaction(
fakeAccessToken, fakeAccessToken,
fakeTxnID+strconv.Itoa(i), fakeTxnID+strconv.Itoa(i),
u,
&util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: strconv.Itoa(i)}}, &util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: strconv.Itoa(i)}},
) )
} }
testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID) testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID, u)
if !ok { if !ok {
t.Error("Failed to retrieve entry for txnID: ", fakeTxnID) t.Error("Failed to retrieve entry for txnID: ", fakeTxnID)
} else if testResponse.JSON != fakeResponse.JSON { } else if testResponse.JSON != fakeResponse.JSON {
@ -59,20 +77,30 @@ func TestCache(t *testing.T) {
} }
// TestCacheScope ensures transactions with the same transaction ID are not shared // TestCacheScope ensures transactions with the same transaction ID are not shared
// across multiple access tokens. // across multiple access tokens and endpoints.
func TestCacheScope(t *testing.T) { func TestCacheScope(t *testing.T) {
cache := New() cache := New()
cache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse) sendEndpoint, _ := url.Parse("/send/1?accessToken=test")
cache.AddTransaction(fakeAccessToken2, fakeTxnID, fakeResponse2) sendToDeviceEndpoint, _ := url.Parse("/sendToDevice/1")
cache.AddTransaction(fakeAccessToken, fakeTxnID, sendEndpoint, fakeResponse)
cache.AddTransaction(fakeAccessToken2, fakeTxnID, sendEndpoint, fakeResponse2)
cache.AddTransaction(fakeAccessToken2, fakeTxnID, sendToDeviceEndpoint, fakeResponse3)
if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID); !ok { if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID, sendEndpoint); !ok {
t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID)
} else if res.JSON != fakeResponse.JSON { } else if res.JSON != fakeResponse.JSON {
t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse.JSON, res.JSON) t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse.JSON, res.JSON)
} }
if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID); !ok { if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID, sendEndpoint); !ok {
t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID)
} else if res.JSON != fakeResponse2.JSON { } else if res.JSON != fakeResponse2.JSON {
t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON) t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON)
} }
// Ensure the txnID is not shared across endpoints
if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID, sendToDeviceEndpoint); !ok {
t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID)
} else if res.JSON != fakeResponse3.JSON {
t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON)
}
} }

View file

@ -17,7 +17,7 @@ var build string
const ( const (
VersionMajor = 0 VersionMajor = 0
VersionMinor = 10 VersionMinor = 10
VersionPatch = 3 VersionPatch = 4
VersionTag = "" // example: "rc1" VersionTag = "" // example: "rc1"
) )

View file

@ -250,6 +250,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap
// nolint:gocyclo // nolint:gocyclo
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
var respMu sync.Mutex
res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
@ -329,7 +330,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
} }
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us // attempt to satisfy key queries from the local database first as we should get device updates pushed to us
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys) domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys)
if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 { if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
// perform key queries for remote devices // perform key queries for remote devices
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
@ -407,7 +408,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
} }
func (a *KeyInternalAPI) remoteKeysFromDatabase( func (a *KeyInternalAPI) remoteKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
) map[string]map[string][]string { ) map[string]map[string][]string {
fetchRemote := make(map[string]map[string][]string) fetchRemote := make(map[string]map[string][]string)
for domain, userToDeviceMap := range domainToDeviceKeys { for domain, userToDeviceMap := range domainToDeviceKeys {
@ -415,7 +416,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
// we can't safely return keys from the db when all devices are requested as we don't // we can't safely return keys from the db when all devices are requested as we don't
// know if one has just been added. // know if one has just been added.
if len(deviceIDs) > 0 { if len(deviceIDs) > 0 {
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs) err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs)
if err == nil { if err == nil {
continue continue
} }
@ -471,7 +472,9 @@ func (a *KeyInternalAPI) queryRemoteKeys(
close(resultCh) close(resultCh)
}() }()
for result := range resultCh { processResult := func(result *gomatrixserverlib.RespQueryKeys) {
respMu.Lock()
defer respMu.Unlock()
for userID, nest := range result.DeviceKeys { for userID, nest := range result.DeviceKeys {
res.DeviceKeys[userID] = make(map[string]json.RawMessage) res.DeviceKeys[userID] = make(map[string]json.RawMessage)
for deviceID, deviceKey := range nest { for deviceID, deviceKey := range nest {
@ -494,6 +497,10 @@ func (a *KeyInternalAPI) queryRemoteKeys(
// TODO: do we want to persist these somewhere now // TODO: do we want to persist these somewhere now
// that we have fetched them? // that we have fetched them?
} }
for result := range resultCh {
processResult(result)
}
} }
func (a *KeyInternalAPI) queryRemoteKeysOnServer( func (a *KeyInternalAPI) queryRemoteKeysOnServer(
@ -541,9 +548,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
} }
// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
// user so the fact that we're populating all devices here isn't a problem so long as we have devices. // user so the fact that we're populating all devices here isn't a problem so long as we have devices.
respMu.Lock() err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil)
respMu.Unlock()
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
logrus.ErrorKey: err, logrus.ErrorKey: err,
@ -567,25 +572,26 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
res.Failures[serverName] = map[string]interface{}{ res.Failures[serverName] = map[string]interface{}{
"message": err.Error(), "message": err.Error(),
} }
respMu.Unlock()
// last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server // last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server
// is down, better to return something than nothing at all. Clients can know about the failure by // is down, better to return something than nothing at all. Clients can know about the failure by
// inspecting the failures map though so they can know it's a cached response. // inspecting the failures map though so they can know it's a cached response.
for userID, dkeys := range devKeys { for userID, dkeys := range devKeys {
// drop the error as it's already a failure at this point // drop the error as it's already a failure at this point
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys) _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys)
} }
// Sytest expects no failures, if we still could retrieve keys, e.g. from local cache // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
respMu.Lock()
if len(res.DeviceKeys) > 0 { if len(res.DeviceKeys) > 0 {
delete(res.Failures, serverName) delete(res.Failures, serverName)
} }
respMu.Unlock() respMu.Unlock()
} }
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string, ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
) error { ) error {
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
// if we can't query the db or there are fewer keys than requested, fetch from remote. // if we can't query the db or there are fewer keys than requested, fetch from remote.
@ -598,9 +604,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
if len(deviceIDs) == 0 && len(keys) == 0 { if len(deviceIDs) == 0 && len(keys) == 0 {
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID) return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
} }
respMu.Lock()
if res.DeviceKeys[userID] == nil { if res.DeviceKeys[userID] == nil {
res.DeviceKeys[userID] = make(map[string]json.RawMessage) res.DeviceKeys[userID] = make(map[string]json.RawMessage)
} }
respMu.Unlock()
for _, key := range keys { for _, key := range keys {
if len(key.KeyJSON) == 0 { if len(key.KeyJSON) == 0 {
@ -610,7 +618,9 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"` DisplayName string `json:"device_display_name,omitempty"`
}{key.DisplayName}) }{key.DisplayName})
respMu.Lock()
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
respMu.Unlock()
} }
return nil return nil
} }

View file

@ -428,6 +428,13 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
return return
} }
// Only notify clients about retired invite events, if the user didn't accept the invite.
// The PDU stream will also receive an event about accepting the invitation, so there should
// be a "smooth" transition from invite -> join, and not invite -> leave -> join
if msg.Membership == gomatrixserverlib.Join {
return
}
// Notify any active sync requests that the invite has been retired. // Notify any active sync requests that the invite has been retired.
s.inviteStream.Advance(pduPos) s.inviteStream.Advance(pduPos)
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID) s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID)

View file

@ -28,8 +28,9 @@ import (
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
const outputRoomEventsSchema = ` const outputRoomEventsSchema = `
@ -133,7 +134,7 @@ const updateEventJSONSQL = "" +
"UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2" "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2"
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). // In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
const selectStateInRangeSQL = "" + const selectStateInRangeFilteredSQL = "" +
"SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" + "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" +
" FROM syncapi_output_room_events" + " FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
@ -146,6 +147,15 @@ const selectStateInRangeSQL = "" +
" ORDER BY id ASC" + " ORDER BY id ASC" +
" LIMIT $9" " LIMIT $9"
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
const selectStateInRangeSQL = "" +
"SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" +
" FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
" AND room_id = ANY($3)" +
" ORDER BY id ASC" +
" LIMIT $4"
const deleteEventsForRoomSQL = "" + const deleteEventsForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1" "DELETE FROM syncapi_output_room_events WHERE room_id = $1"
@ -178,6 +188,7 @@ type outputRoomEventsStatements struct {
selectRecentEventsStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt
selectRecentEventsForSyncStmt *sql.Stmt selectRecentEventsForSyncStmt *sql.Stmt
selectEarlyEventsStmt *sql.Stmt selectEarlyEventsStmt *sql.Stmt
selectStateInRangeFilteredStmt *sql.Stmt
selectStateInRangeStmt *sql.Stmt selectStateInRangeStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt deleteEventsForRoomStmt *sql.Stmt
@ -214,6 +225,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
{&s.selectRecentEventsStmt, selectRecentEventsSQL}, {&s.selectRecentEventsStmt, selectRecentEventsSQL},
{&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL}, {&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL},
{&s.selectEarlyEventsStmt, selectEarlyEventsSQL}, {&s.selectEarlyEventsStmt, selectEarlyEventsSQL},
{&s.selectStateInRangeFilteredStmt, selectStateInRangeFilteredSQL},
{&s.selectStateInRangeStmt, selectStateInRangeSQL}, {&s.selectStateInRangeStmt, selectStateInRangeSQL},
{&s.updateEventJSONStmt, updateEventJSONSQL}, {&s.updateEventJSONStmt, updateEventJSONSQL},
{&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL}, {&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL},
@ -240,9 +252,12 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
ctx context.Context, txn *sql.Tx, r types.Range, ctx context.Context, txn *sql.Tx, r types.Range,
stateFilter *gomatrixserverlib.StateFilter, roomIDs []string, stateFilter *gomatrixserverlib.StateFilter, roomIDs []string,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) { ) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) var rows *sql.Rows
var err error
if stateFilter != nil {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeFilteredStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter) senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err := stmt.QueryContext( rows, err = stmt.QueryContext(
ctx, r.Low(), r.High(), pq.StringArray(roomIDs), ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
pq.StringArray(senders), pq.StringArray(senders),
pq.StringArray(notSenders), pq.StringArray(notSenders),
@ -251,6 +266,14 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
stateFilter.ContainsURL, stateFilter.ContainsURL,
stateFilter.Limit, stateFilter.Limit,
) )
} else {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)
rows, err = stmt.QueryContext(
ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
r.High()-r.Low(),
)
}
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View file

@ -5,10 +5,11 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
) )
type DatabaseTransaction struct { type DatabaseTransaction struct {
@ -277,6 +278,7 @@ func (d *DatabaseTransaction) GetBackwardTopologyPos(
// exclusive of oldPos, inclusive of newPos, for the rooms in which // exclusive of oldPos, inclusive of newPos, for the rooms in which
// the user has new membership events. // the user has new membership events.
// A list of joined room IDs is also returned in case the caller needs it. // A list of joined room IDs is also returned in case the caller needs it.
// nolint:gocyclo
func (d *DatabaseTransaction) GetStateDeltas( func (d *DatabaseTransaction) GetStateDeltas(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
@ -311,7 +313,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
} }
// get all the state events ever (i.e. for all available rooms) between these two positions // get all the state events ever (i.e. for all available rooms) between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, nil, allRoomIDs)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, nil return nil, nil, nil
@ -326,6 +328,22 @@ func (d *DatabaseTransaction) GetStateDeltas(
return nil, nil, err return nil, nil, err
} }
// get all the state events ever (i.e. for all available rooms) between these two positions
stateNeededFiltered, eventMapFiltered, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
stateFiltered, err := d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
// find out which rooms this user is peeking, if any. // find out which rooms this user is peeking, if any.
// We do this before joins so any peeks get overwritten // We do this before joins so any peeks get overwritten
peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r) peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r)
@ -371,6 +389,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
// If our membership is now join but the previous membership wasn't // If our membership is now join but the previous membership wasn't
// then this is a "join transition", so we'll insert this room. // then this is a "join transition", so we'll insert this room.
if prevMembership != membership { if prevMembership != membership {
newlyJoinedRooms[roomID] = true
// Get the full room state, as we'll send that down for a newly // Get the full room state, as we'll send that down for a newly
// joined room instead of a delta. // joined room instead of a delta.
var s []types.StreamEvent var s []types.StreamEvent
@ -383,8 +402,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
// Add the information for this room into the state so that // Add the information for this room into the state so that
// it will get added with all of the rest of the joined rooms. // it will get added with all of the rest of the joined rooms.
state[roomID] = s stateFiltered[roomID] = s
newlyJoinedRooms[roomID] = true
} }
// We won't add joined rooms into the delta at this point as they // We won't add joined rooms into the delta at this point as they
@ -395,7 +413,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
deltas = append(deltas, types.StateDelta{ deltas = append(deltas, types.StateDelta{
Membership: membership, Membership: membership,
MembershipPos: ev.StreamPosition, MembershipPos: ev.StreamPosition,
StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), StateEvents: d.StreamEventsToEvents(device, stateFiltered[roomID]),
RoomID: roomID, RoomID: roomID,
}) })
break break
@ -407,7 +425,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
for _, joinedRoomID := range joinedRoomIDs { for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, types.StateDelta{ deltas = append(deltas, types.StateDelta{
Membership: gomatrixserverlib.Join, Membership: gomatrixserverlib.Join,
StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), StateEvents: d.StreamEventsToEvents(device, stateFiltered[joinedRoomID]),
RoomID: joinedRoomID, RoomID: joinedRoomID,
NewlyJoined: newlyJoinedRooms[joinedRoomID], NewlyJoined: newlyJoinedRooms[joinedRoomID],
}) })

View file

@ -29,8 +29,9 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
const outputRoomEventsSchema = ` const outputRoomEventsSchema = `
@ -189,21 +190,36 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
for _, roomID := range roomIDs { for _, roomID := range roomIDs {
inputParams = append(inputParams, roomID) inputParams = append(inputParams, roomID)
} }
stmt, params, err := prepareWithFilters( var (
stmt *sql.Stmt
params []any
err error
)
if stateFilter != nil {
stmt, params, err = prepareWithFilters(
s.db, txn, stmtSQL, inputParams, s.db, txn, stmtSQL, inputParams,
stateFilter.Senders, stateFilter.NotSenders, stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes, stateFilter.Types, stateFilter.NotTypes,
nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc, nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc,
) )
} else {
stmt, params, err = prepareWithFilters(
s.db, txn, stmtSQL, inputParams,
nil, nil,
nil, nil,
nil, nil, int(r.High()-r.Low()), FilterOrderAsc,
)
}
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
} }
defer internal.CloseAndLogIfError(ctx, stmt, "selectStateInRange: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer rows.Close() // nolint: errcheck defer internal.CloseAndLogIfError(ctx, rows, "selectStateInRange: rows.close() failed")
// Fetch all the state change events for all rooms between the two positions then loop each event and: // Fetch all the state change events for all rooms between the two positions then loop each event and:
// - Keep a cache of the event by ID (99% of state change events are for the event itself) // - Keep a cache of the event by ID (99% of state change events are for the event itself)
// - For each room ID, build up an array of event IDs which represents cumulative adds/removes // - For each room ID, build up an array of event IDs which represents cumulative adds/removes
@ -269,6 +285,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID(
) (id int64, err error) { ) (id int64, err error) {
var nullableID sql.NullInt64 var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt)
defer internal.CloseAndLogIfError(ctx, stmt, "SelectMaxEventID: stmt.close() failed")
err = stmt.QueryRowContext(ctx).Scan(&nullableID) err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid { if nullableID.Valid {
id = nullableID.Int64 id = nullableID.Int64
@ -323,6 +340,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
return 0, err return 0, err
} }
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
defer internal.CloseAndLogIfError(ctx, insertStmt, "InsertEvent: stmt.close() failed")
_, err = insertStmt.ExecContext( _, err = insertStmt.ExecContext(
ctx, ctx,
streamPos, streamPos,
@ -367,6 +385,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
if err != nil { if err != nil {
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
} }
defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
@ -415,6 +434,8 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
if err != nil { if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
} }
defer internal.CloseAndLogIfError(ctx, stmt, "SelectEarlyEvents: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -456,6 +477,8 @@ func (s *outputRoomEventsStatements) SelectEvents(
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, stmt, "SelectEvents: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -558,6 +581,10 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
filter.Types, filter.NotTypes, filter.Types, filter.NotTypes,
nil, filter.ContainsURL, filter.Limit, FilterOrderDesc, nil, filter.ContainsURL, filter.Limit, FilterOrderDesc,
) )
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectContextBeforeEvent: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
@ -596,6 +623,10 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent(
filter.Types, filter.NotTypes, filter.Types, filter.NotTypes,
nil, filter.ContainsURL, filter.Limit, FilterOrderAsc, nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
) )
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectContextAfterEvent: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {

View file

@ -0,0 +1,198 @@
package tables_test
import (
"context"
"database/sql"
"reflect"
"sort"
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
)
func newMembershipsTable(t *testing.T, dbType test.DBType) (tables.Memberships, *sql.DB, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.Memberships
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresMembershipsTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSqliteMembershipsTable(db)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, db, close
}
func TestMembershipsTable(t *testing.T) {
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
// Create users
var userEvents []*gomatrixserverlib.HeaderedEvent
users := []string{alice.ID}
for _, x := range room.CurrentState() {
if x.StateKeyEquals(alice.ID) {
if _, err := x.Membership(); err == nil {
userEvents = append(userEvents, x)
break
}
}
}
if len(userEvents) == 0 {
t.Fatalf("didn't find creator membership event")
}
for i := 0; i < 10; i++ {
u := test.NewUser(t)
users = append(users, u.ID)
ev := room.CreateAndInsert(t, u, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(u.ID))
userEvents = append(userEvents, ev)
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
table, _, close := newMembershipsTable(t, dbType)
defer close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
for _, ev := range userEvents {
if err := table.UpsertMembership(ctx, nil, ev, types.StreamPosition(ev.Depth()), 1); err != nil {
t.Fatalf("failed to upsert membership: %s", err)
}
}
testUpsert(t, ctx, table, userEvents[0], alice, room)
testMembershipCount(t, ctx, table, room)
testHeroes(t, ctx, table, alice, room, users)
})
}
func testHeroes(t *testing.T, ctx context.Context, table tables.Memberships, user *test.User, room *test.Room, users []string) {
// Re-slice and sort the expected users
users = users[1:]
sort.Strings(users)
type testCase struct {
name string
memberships []string
wantHeroes []string
}
testCases := []testCase{
{name: "no memberships queried", memberships: []string{}},
{name: "joined memberships queried should be limited", memberships: []string{gomatrixserverlib.Join}, wantHeroes: users[:5]},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := table.SelectHeroes(ctx, nil, room.ID, user.ID, tc.memberships)
if err != nil {
t.Fatalf("unable to select heroes: %s", err)
}
if gotLen := len(got); gotLen != len(tc.wantHeroes) {
t.Fatalf("expected %d heroes, got %d", len(tc.wantHeroes), gotLen)
}
if !reflect.DeepEqual(got, tc.wantHeroes) {
t.Fatalf("expected heroes to be %+v, got %+v", tc.wantHeroes, got)
}
})
}
}
func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) {
t.Run("membership counts are correct", func(t *testing.T) {
// After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users)
count, err := table.SelectMembershipCount(ctx, nil, room.ID, gomatrixserverlib.Join, 10)
if err != nil {
t.Fatalf("failed to get membership count: %s", err)
}
expectedCount := 6
if expectedCount != count {
t.Fatalf("expected member count to be %d, got %d", expectedCount, count)
}
// After 100 events, we should have all 11 users
count, err = table.SelectMembershipCount(ctx, nil, room.ID, gomatrixserverlib.Join, 100)
if err != nil {
t.Fatalf("failed to get membership count: %s", err)
}
expectedCount = 11
if expectedCount != count {
t.Fatalf("expected member count to be %d, got %d", expectedCount, count)
}
})
}
func testUpsert(t *testing.T, ctx context.Context, table tables.Memberships, membershipEvent *gomatrixserverlib.HeaderedEvent, user *test.User, room *test.Room) {
t.Run("upserting works as expected", func(t *testing.T) {
if err := table.UpsertMembership(ctx, nil, membershipEvent, 1, 1); err != nil {
t.Fatalf("failed to upsert membership: %s", err)
}
membership, pos, err := table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 1)
if err != nil {
t.Fatalf("failed to select membership: %s", err)
}
expectedPos := 1
if pos != expectedPos {
t.Fatalf("expected pos to be %d, got %d", expectedPos, pos)
}
if membership != gomatrixserverlib.Join {
t.Fatalf("expected membership to be join, got %s", membership)
}
// Create a new event which gets upserted and should not cause issues
ev := room.CreateAndInsert(t, user, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": gomatrixserverlib.Join,
}, test.WithStateKey(user.ID))
// Insert the same event again, but with different positions, which should get updated
if err = table.UpsertMembership(ctx, nil, ev, 2, 2); err != nil {
t.Fatalf("failed to upsert membership: %s", err)
}
// Verify the position got updated
membership, pos, err = table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 10)
if err != nil {
t.Fatalf("failed to select membership: %s", err)
}
expectedPos = 2
if pos != expectedPos {
t.Fatalf("expected pos to be %d, got %d", expectedPos, pos)
}
if membership != gomatrixserverlib.Join {
t.Fatalf("expected membership to be join, got %s", membership)
}
// If we can't find a membership, it should default to leave
if membership, _, err = table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 1); err != nil {
t.Fatalf("failed to select membership: %s", err)
}
if membership != gomatrixserverlib.Leave {
t.Fatalf("expected membership to be leave, got %s", membership)
}
})
}

View file

@ -74,7 +74,12 @@ func (p *InviteStreamProvider) IncrementalSync(
return to return to
} }
for roomID := range retiredInvites { for roomID := range retiredInvites {
if _, ok := req.Response.Rooms.Join[roomID]; !ok { if _, ok := req.Response.Rooms.Invite[roomID]; ok {
continue
}
if _, ok := req.Response.Rooms.Join[roomID]; ok {
continue
}
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...)) h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...))
lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{ lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{
@ -88,7 +93,7 @@ func (p *InviteStreamProvider) IncrementalSync(
Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`), Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`),
}) })
req.Response.Rooms.Leave[roomID] = lr req.Response.Rooms.Leave[roomID] = lr
}
} }
return maxID return maxID

View file

@ -194,7 +194,7 @@ func (p *PDUStreamProvider) IncrementalSync(
} }
} }
var pos types.StreamPosition var pos types.StreamPosition
if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil { if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req); err != nil {
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed") req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
if err == context.DeadlineExceeded || err == context.Canceled || err == sql.ErrTxDone { if err == context.DeadlineExceeded || err == context.Canceled || err == sql.ErrTxDone {
return newPos return newPos
@ -225,7 +225,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
delta types.StateDelta, delta types.StateDelta,
eventFilter *gomatrixserverlib.RoomEventFilter, eventFilter *gomatrixserverlib.RoomEventFilter,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
res *types.Response, req *types.SyncRequest,
) (types.StreamPosition, error) { ) (types.StreamPosition, error) {
if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave { if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave {
// make sure we don't leak recent events after the leave event. // make sure we don't leak recent events after the leave event.
@ -290,8 +290,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
hasMembershipChange := false hasMembershipChange := false
for _, recentEvent := range recentStreamEvents { for _, recentEvent := range recentStreamEvents {
if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil { if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil {
if membership, _ := recentEvent.Membership(); membership == gomatrixserverlib.Join {
req.MembershipChanges[*recentEvent.StateKey()] = struct{}{}
}
hasMembershipChange = true hasMembershipChange = true
break
} }
} }
@ -318,9 +320,9 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents) jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Join[delta.RoomID] = jr req.Response.Rooms.Join[delta.RoomID] = jr
case gomatrixserverlib.Peek: case gomatrixserverlib.Peek:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
@ -329,7 +331,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Peek[delta.RoomID] = jr req.Response.Rooms.Peek[delta.RoomID] = jr
case gomatrixserverlib.Leave: case gomatrixserverlib.Leave:
fallthrough // transitions to leave are the same as ban fallthrough // transitions to leave are the same as ban
@ -342,7 +344,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
lr.Timeline.Limited = limited && len(events) == len(recentEvents) lr.Timeline.Limited = limited && len(events) == len(recentEvents)
lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Leave[delta.RoomID] = lr req.Response.Rooms.Leave[delta.RoomID] = lr
} }
return latestPosition, nil return latestPosition, nil

View file

@ -101,6 +101,7 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
Timeout: timeout, // Timeout: timeout, //
Rooms: make(map[string]string), // Populated by the PDU stream Rooms: make(map[string]string), // Populated by the PDU stream
WantFullState: wantFullState, // WantFullState: wantFullState, //
MembershipChanges: make(map[string]struct{}), // Populated by the PDU stream
}, nil }, nil
} }

View file

@ -4,9 +4,10 @@ import (
"context" "context"
"time" "time"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
userapi "github.com/matrix-org/dendrite/userapi/api"
) )
type SyncRequest struct { type SyncRequest struct {
@ -22,6 +23,8 @@ type SyncRequest struct {
// Updated by the PDU stream. // Updated by the PDU stream.
Rooms map[string]string Rooms map[string]string
// Updated by the PDU stream. // Updated by the PDU stream.
MembershipChanges map[string]struct{}
// Updated by the PDU stream.
IgnoredUsers IgnoredUsers IgnoredUsers IgnoredUsers
} }

View file

@ -22,10 +22,6 @@ Forgotten room messages cannot be paginated
Local device key changes get to remote servers with correct prev_id Local device key changes get to remote servers with correct prev_id
# Flakey
Local device key changes appear in /keys/changes
# we don't support groups # we don't support groups
Remove group category Remove group category
@ -39,12 +35,6 @@ Events in rooms with AS-hosted room aliases are sent to AS server
Inviting an AS-hosted user asks the AS server Inviting an AS-hosted user asks the AS server
Accesing an AS-hosted room alias asks the AS server Accesing an AS-hosted room alias asks the AS server
# Flakey, need additional investigation
Messages that notify from another user increment notification_count
Messages that highlight from another user increment unread highlight count
Notifications can be viewed with GET /notifications
# More flakey # More flakey
Guest users can join guest_access rooms Guest users can join guest_access rooms

View file

@ -746,3 +746,9 @@ Inbound federation can return missing events for joined visibility
outliers whose auth_events are in a different room are correctly rejected outliers whose auth_events are in a different room are correctly rejected
Messages that notify from another user increment notification_count Messages that notify from another user increment notification_count
Messages that highlight from another user increment unread highlight count Messages that highlight from another user increment unread highlight count
Newly joined room has correct timeline in incremental sync
When user joins a room the state is included in the next sync
When user joins a room the state is included in a gapped sync
Messages that notify from another user increment notification_count
Messages that highlight from another user increment unread highlight count
Notifications can be viewed with GET /notifications

View file

@ -96,7 +96,7 @@ type ClientUserAPI interface {
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error
SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
@ -579,7 +579,10 @@ type Notification struct {
type PerformSetAvatarURLRequest struct { type PerformSetAvatarURLRequest struct {
Localpart, AvatarURL string Localpart, AvatarURL string
} }
type PerformSetAvatarURLResponse struct{} type PerformSetAvatarURLResponse struct {
Profile *authtypes.Profile `json:"profile"`
Changed bool `json:"changed"`
}
type QueryNumericLocalpartResponse struct { type QueryNumericLocalpartResponse struct {
ID int64 ID int64
@ -606,6 +609,11 @@ type PerformUpdateDisplayNameRequest struct {
Localpart, DisplayName string Localpart, DisplayName string
} }
type PerformUpdateDisplayNameResponse struct {
Profile *authtypes.Profile `json:"profile"`
Changed bool `json:"changed"`
}
type QueryLocalpartForThreePIDRequest struct { type QueryLocalpartForThreePIDRequest struct {
ThreePID, Medium string ThreePID, Medium string
} }

View file

@ -168,7 +168,7 @@ func (t *UserInternalAPITrace) QueryAccountAvailability(ctx context.Context, req
return err return err
} }
func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error { func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error {
err := t.Impl.SetDisplayName(ctx, req, res) err := t.Impl.SetDisplayName(ctx, req, res)
util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res))
return err return err

View file

@ -81,7 +81,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
readPos := msg.Header.Get(jetstream.EventID) readPos := msg.Header.Get(jetstream.EventID)
evType := msg.Header.Get("type") evType := msg.Header.Get("type")
if readPos == "" || evType != "m.read" { if readPos == "" || (evType != "m.read" && evType != "m.read.private") {
return true return true
} }

View file

@ -10,19 +10,24 @@ import (
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
) )
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
t.Helper() t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr), ConnectionString: config.DataSource(connStr),
}, "", 4, 0, 0, "") }, "", 4, 0, 0, "")
if err != nil { if err != nil {
t.Fatalf("failed to create new user db: %v", err) t.Fatalf("failed to create new user db: %v", err)
} }
return db, close return db, func() {
close()
baseclose()
}
} }
func mustCreateEvent(t *testing.T, content string) *gomatrixserverlib.HeaderedEvent { func mustCreateEvent(t *testing.T, content string) *gomatrixserverlib.HeaderedEvent {

View file

@ -170,7 +170,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil return nil
} }
if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err return err
} }
@ -813,7 +813,10 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
} }
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
return a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
res.Profile = profile
res.Changed = changed
return err
} }
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
@ -847,8 +850,11 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
} }
} }
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, _ *struct{}) error { func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
return a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
res.Profile = profile
res.Changed = changed
return err
} }
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {

View file

@ -388,7 +388,7 @@ func (h *httpUserInternalAPI) QueryAccountByPassword(
func (h *httpUserInternalAPI) SetDisplayName( func (h *httpUserInternalAPI) SetDisplayName(
ctx context.Context, ctx context.Context,
request *api.PerformUpdateDisplayNameRequest, request *api.PerformUpdateDisplayNameRequest,
response *struct{}, response *api.PerformUpdateDisplayNameResponse,
) error { ) error {
return httputil.CallInternalRPCAPI( return httputil.CallInternalRPCAPI(
"SetDisplayName", h.apiURL+PerformSetDisplayNamePath, "SetDisplayName", h.apiURL+PerformSetDisplayNamePath,

View file

@ -29,8 +29,8 @@ import (
type Profile interface { type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, localpart string, displayName string) error SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error)
} }
type Account interface { type Account interface {

View file

@ -26,7 +26,7 @@ import (
const accountDataSchema = ` const accountDataSchema = `
-- Stores data about accounts data. -- Stores data about accounts data.
CREATE TABLE IF NOT EXISTS account_data ( CREATE TABLE IF NOT EXISTS userapi_account_datas (
-- The Matrix user ID localpart for this account -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room) -- The room ID for this data (empty string if not specific to a room)
@ -41,15 +41,15 @@ CREATE TABLE IF NOT EXISTS account_data (
` `
const insertAccountDataSQL = ` const insertAccountDataSQL = `
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
` `
const selectAccountDataSQL = "" + const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM account_data WHERE localpart = $1" "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
const selectAccountDataByTypeSQL = "" + const selectAccountDataByTypeSQL = "" +
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
type accountDataStatements struct { type accountDataStatements struct {
insertAccountDataStmt *sql.Stmt insertAccountDataStmt *sql.Stmt

View file

@ -32,7 +32,7 @@ import (
const accountsSchema = ` const accountsSchema = `
-- Stores data about accounts. -- Stores data about accounts.
CREATE TABLE IF NOT EXISTS account_accounts ( CREATE TABLE IF NOT EXISTS userapi_accounts (
-- The Matrix user ID localpart for this account -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,
-- When this account was first created, as a unix timestamp (ms resolution). -- When this account was first created, as a unix timestamp (ms resolution).
@ -51,22 +51,22 @@ CREATE TABLE IF NOT EXISTS account_accounts (
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
const updatePasswordSQL = "" + const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
const deactivateAccountSQL = "" + const deactivateAccountSQL = "" +
"UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1" "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1"
const selectAccountByLocalpartSQL = "" + const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" + const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM account_accounts WHERE localpart ~ '^[0-9]{1,}$'" "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'"
type accountsStatements struct { type accountsStatements struct {
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt

View file

@ -7,7 +7,7 @@ import (
) )
func UpIsActive(ctx context.Context, tx *sql.Tx) error { func UpIsActive(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;") _, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;")
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }
@ -15,7 +15,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
} }
func DownIsActive(ctx context.Context, tx *sql.Tx) error { func DownIsActive(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN is_deactivated;") _, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts DROP COLUMN is_deactivated;")
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

View file

@ -8,9 +8,9 @@ import (
func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, ` _, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000; ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS ip TEXT; ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS ip TEXT;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`) ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }
@ -19,9 +19,9 @@ ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, ` _, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices DROP COLUMN last_seen_ts; ALTER TABLE userapi_devices DROP COLUMN last_seen_ts;
ALTER TABLE device_devices DROP COLUMN ip; ALTER TABLE userapi_devices DROP COLUMN ip;
ALTER TABLE device_devices DROP COLUMN user_agent;`) ALTER TABLE userapi_devices DROP COLUMN user_agent;`)
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

View file

@ -9,10 +9,10 @@ import (
func UpAddAccountType(ctx context.Context, tx *sql.Tx) error { func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
// initially set every account to useraccount, change appservice and guest accounts afterwards // initially set every account to useraccount, change appservice and guest accounts afterwards
// (user = 1, guest = 2, admin = 3, appservice = 4) // (user = 1, guest = 2, admin = 3, appservice = 4)
_, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1; _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1;
UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; UPDATE userapi_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$'; UPDATE userapi_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$';
ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`, ALTER TABLE userapi_accounts ALTER COLUMN account_type DROP DEFAULT;`,
) )
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
@ -21,7 +21,7 @@ ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`,
} }
func DownAddAccountType(ctx context.Context, tx *sql.Tx) error { func DownAddAccountType(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN account_type;") _, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts DROP COLUMN account_type;")
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

View file

@ -8,7 +8,7 @@ import (
func UpNoGuests(ctx context.Context, tx *sql.Tx) error { func UpNoGuests(ctx context.Context, tx *sql.Tx) error {
// AddAccountType introduced a bug where each user that had was registered as a regular user, but without user_id, became a guest. // AddAccountType introduced a bug where each user that had was registered as a regular user, but without user_id, became a guest.
_, err := tx.ExecContext(ctx, "UPDATE account_accounts SET account_type = 1 WHERE account_type = 2;") _, err := tx.ExecContext(ctx, "UPDATE userapi_accounts SET account_type = 1 WHERE account_type = 2;")
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }

View file

@ -0,0 +1,102 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
)
var renameTableMappings = map[string]string{
"account_accounts": "userapi_accounts",
"account_data": "userapi_account_datas",
"device_devices": "userapi_devices",
"account_e2e_room_keys": "userapi_key_backups",
"account_e2e_room_keys_versions": "userapi_key_backup_versions",
"login_tokens": "userapi_login_tokens",
"open_id_tokens": "userapi_openid_tokens",
"account_profiles": "userapi_profiles",
"account_threepid": "userapi_threepids",
}
var renameSequenceMappings = map[string]string{
"device_session_id_seq": "userapi_device_session_id_seq",
"account_e2e_room_keys_versions_seq": "userapi_key_backup_versions_seq",
}
var renameIndicesMappings = map[string]string{
"device_localpart_id_idx": "userapi_device_localpart_id_idx",
"e2e_room_keys_idx": "userapi_key_backups_idx",
"e2e_room_keys_versions_idx": "userapi_key_backups_versions_idx",
"account_e2e_room_keys_versions_idx": "userapi_key_backup_versions_idx",
"login_tokens_expiration_idx": "userapi_login_tokens_expiration_idx",
"account_threepid_localpart": "userapi_threepid_idx",
}
// I know what you're thinking: you're wondering "why doesn't this use $1
// and pass variadic parameters to ExecContext?" — the answer is because
// PostgreSQL doesn't expect the table name to be specified as a substituted
// argument in that way so it results in a syntax error in the query.
func UpRenameTables(ctx context.Context, tx *sql.Tx) error {
for old, new := range renameTableMappings {
q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(old), pq.QuoteIdentifier(new),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
}
}
for old, new := range renameSequenceMappings {
q := fmt.Sprintf(
"ALTER SEQUENCE IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(old), pq.QuoteIdentifier(new),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
}
}
for old, new := range renameIndicesMappings {
q := fmt.Sprintf(
"ALTER INDEX IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(old), pq.QuoteIdentifier(new),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
}
}
return nil
}
func DownRenameTables(ctx context.Context, tx *sql.Tx) error {
for old, new := range renameTableMappings {
q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(new), pq.QuoteIdentifier(old),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
}
}
for old, new := range renameSequenceMappings {
q := fmt.Sprintf(
"ALTER SEQUENCE IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(new), pq.QuoteIdentifier(old),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
}
}
for old, new := range renameIndicesMappings {
q := fmt.Sprintf(
"ALTER INDEX IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(new), pq.QuoteIdentifier(old),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
}
}
return nil
}

View file

@ -31,10 +31,10 @@ import (
const devicesSchema = ` const devicesSchema = `
-- This sequence is used for automatic allocation of session_id. -- This sequence is used for automatic allocation of session_id.
CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; CREATE SEQUENCE IF NOT EXISTS userapi_device_session_id_seq START 1;
-- Stores data about devices. -- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices ( CREATE TABLE IF NOT EXISTS userapi_devices (
-- The access token granted to this device. This has to be the primary key -- The access token granted to this device. This has to be the primary key
-- so we can distinguish which device is making a given request. -- so we can distinguish which device is making a given request.
access_token TEXT NOT NULL PRIMARY KEY, access_token TEXT NOT NULL PRIMARY KEY,
@ -42,7 +42,7 @@ CREATE TABLE IF NOT EXISTS device_devices (
-- This can be used as a secure substitution of the access token in situations -- This can be used as a secure substitution of the access token in situations
-- where data is associated with access tokens (e.g. transaction storage), -- where data is associated with access tokens (e.g. transaction storage),
-- so we don't have to store users' access tokens everywhere. -- so we don't have to store users' access tokens everywhere.
session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'), session_id BIGINT NOT NULL DEFAULT nextval('userapi_device_session_id_seq'),
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally. -- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
-- access_tokens will be clobbered based on the device ID for a user. -- access_tokens will be clobbered based on the device ID for a user.
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
@ -65,39 +65,39 @@ CREATE TABLE IF NOT EXISTS device_devices (
); );
-- Device IDs must be unique for a given user. -- Device IDs must be unique for a given user.
CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(localpart, device_id); CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id);
` `
const insertDeviceSQL = "" + const insertDeviceSQL = "" +
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + "INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
" RETURNING session_id" " RETURNING session_id"
const selectDeviceByTokenSQL = "" + const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" + const selectDeviceByIDSQL = "" +
"SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" + const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
const deleteDeviceSQL = "" + const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" + const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" + const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)"
const selectDevicesByIDSQL = "" + const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" + const updateDeviceLastSeen = "" +
"UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
type devicesStatements struct { type devicesStatements struct {
insertDeviceStmt *sql.Stmt insertDeviceStmt *sql.Stmt

View file

@ -26,7 +26,7 @@ import (
) )
const keyBackupTableSchema = ` const keyBackupTableSchema = `
CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( CREATE TABLE IF NOT EXISTS userapi_key_backups (
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
session_id TEXT NOT NULL, session_id TEXT NOT NULL,
@ -37,31 +37,31 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
is_verified BOOLEAN NOT NULL, is_verified BOOLEAN NOT NULL,
session_data TEXT NOT NULL session_data TEXT NOT NULL
); );
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backups_idx ON userapi_key_backups(user_id, room_id, session_id, version);
CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); CREATE INDEX IF NOT EXISTS userapi_key_backups_versions_idx ON userapi_key_backups(user_id, version);
` `
const insertBackupKeySQL = "" + const insertBackupKeySQL = "" +
"INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + "INSERT INTO userapi_key_backups(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
const updateBackupKeySQL = "" + const updateBackupKeySQL = "" +
"UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + "UPDATE userapi_key_backups SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
"WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8"
const countKeysSQL = "" + const countKeysSQL = "" +
"SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
const selectKeysSQL = "" + const selectKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2" "WHERE user_id = $1 AND version = $2"
const selectKeysByRoomIDSQL = "" + const selectKeysByRoomIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3" "WHERE user_id = $1 AND version = $2 AND room_id = $3"
const selectKeysByRoomIDAndSessionIDSQL = "" + const selectKeysByRoomIDAndSessionIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
type keyBackupStatements struct { type keyBackupStatements struct {

View file

@ -26,40 +26,40 @@ import (
) )
const keyBackupVersionTableSchema = ` const keyBackupVersionTableSchema = `
CREATE SEQUENCE IF NOT EXISTS account_e2e_room_keys_versions_seq; CREATE SEQUENCE IF NOT EXISTS userapi_key_backup_versions_seq;
-- the metadata for each generation of encrypted e2e session backups -- the metadata for each generation of encrypted e2e session backups
CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( CREATE TABLE IF NOT EXISTS userapi_key_backup_versions (
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
-- this means no 2 users will ever have the same version of e2e session backups which strictly -- this means no 2 users will ever have the same version of e2e session backups which strictly
-- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1.
version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'), version BIGINT DEFAULT nextval('userapi_key_backup_versions_seq'),
algorithm TEXT NOT NULL, algorithm TEXT NOT NULL,
auth_data TEXT NOT NULL, auth_data TEXT NOT NULL,
etag TEXT NOT NULL, etag TEXT NOT NULL,
deleted SMALLINT DEFAULT 0 NOT NULL deleted SMALLINT DEFAULT 0 NOT NULL
); );
CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backup_versions_idx ON userapi_key_backup_versions(user_id, version);
` `
const insertKeyBackupSQL = "" + const insertKeyBackupSQL = "" +
"INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" "INSERT INTO userapi_key_backup_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
const updateKeyBackupAuthDataSQL = "" + const updateKeyBackupAuthDataSQL = "" +
"UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" "UPDATE userapi_key_backup_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
const updateKeyBackupETagSQL = "" + const updateKeyBackupETagSQL = "" +
"UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" "UPDATE userapi_key_backup_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
const deleteKeyBackupSQL = "" + const deleteKeyBackupSQL = "" +
"UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" "UPDATE userapi_key_backup_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
const selectKeyBackupSQL = "" + const selectKeyBackupSQL = "" +
"SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" "SELECT algorithm, auth_data, etag, deleted FROM userapi_key_backup_versions WHERE user_id = $1 AND version = $2"
const selectLatestVersionSQL = "" + const selectLatestVersionSQL = "" +
"SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" "SELECT MAX(version) FROM userapi_key_backup_versions WHERE user_id = $1"
type keyBackupVersionStatements struct { type keyBackupVersionStatements struct {
insertKeyBackupStmt *sql.Stmt insertKeyBackupStmt *sql.Stmt

View file

@ -26,7 +26,7 @@ import (
) )
const loginTokenSchema = ` const loginTokenSchema = `
CREATE TABLE IF NOT EXISTS login_tokens ( CREATE TABLE IF NOT EXISTS userapi_login_tokens (
-- The random value of the token issued to a user -- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY, token TEXT NOT NULL PRIMARY KEY,
-- When the token expires -- When the token expires
@ -37,17 +37,17 @@ CREATE TABLE IF NOT EXISTS login_tokens (
); );
-- This index allows efficient garbage collection of expired tokens. -- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); CREATE INDEX IF NOT EXISTS userapi_login_tokens_expiration_idx ON userapi_login_tokens(token_expires_at);
` `
const insertLoginTokenSQL = "" + const insertLoginTokenSQL = "" +
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" "INSERT INTO userapi_login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
const deleteLoginTokenSQL = "" + const deleteLoginTokenSQL = "" +
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" "DELETE FROM userapi_login_tokens WHERE token = $1 OR token_expires_at <= $2"
const selectLoginTokenSQL = "" + const selectLoginTokenSQL = "" +
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" "SELECT user_id FROM userapi_login_tokens WHERE token = $1 AND token_expires_at > $2"
type loginTokenStatements struct { type loginTokenStatements struct {
insertStmt *sql.Stmt insertStmt *sql.Stmt
@ -78,7 +78,7 @@ func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx
// deleteByToken removes the named token. // deleteByToken removes the named token.
// //
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. // As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient. // The userapi_login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt) stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) res, err := stmt.ExecContext(ctx, token, time.Now().UTC())

View file

@ -13,7 +13,7 @@ import (
const openIDTokenSchema = ` const openIDTokenSchema = `
-- Stores data about openid tokens issued for accounts. -- Stores data about openid tokens issued for accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens ( CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
-- The value of the token issued to a user -- The value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY, token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account -- The Matrix user ID for this account
@ -24,10 +24,10 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
` `
const insertOpenIDTokenSQL = "" + const insertOpenIDTokenSQL = "" +
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectOpenIDTokenSQL = "" + const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct { type openIDTokenStatements struct {
insertTokenStmt *sql.Stmt insertTokenStmt *sql.Stmt

View file

@ -27,7 +27,7 @@ import (
const profilesSchema = ` const profilesSchema = `
-- Stores data about accounts profiles. -- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS account_profiles ( CREATE TABLE IF NOT EXISTS userapi_profiles (
-- The Matrix user ID localpart for this account -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,
-- The display name for this account -- The display name for this account
@ -38,19 +38,27 @@ CREATE TABLE IF NOT EXISTS account_profiles (
` `
const insertProfileSQL = "" + const insertProfileSQL = "" +
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
const selectProfileByLocalpartSQL = "" + const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" + const setAvatarURLSQL = "" +
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" "UPDATE userapi_profiles AS new" +
" SET avatar_url = $1" +
" FROM userapi_profiles AS old" +
" WHERE new.localpart = $2" +
" RETURNING new.display_name, old.avatar_url <> new.avatar_url"
const setDisplayNameSQL = "" + const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" "UPDATE userapi_profiles AS new" +
" SET display_name = $1" +
" FROM userapi_profiles AS old" +
" WHERE new.localpart = $2" +
" RETURNING new.avatar_url, old.display_name <> new.display_name"
const selectProfilesBySearchSQL = "" + const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct { type profilesStatements struct {
serverNoticesLocalpart string serverNoticesLocalpart string
@ -100,16 +108,28 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL( func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) { ) (*authtypes.Profile, bool, error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) profile := &authtypes.Profile{
return Localpart: localpart,
AvatarURL: avatarURL,
}
var changed bool
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed)
return profile, changed, err
} }
func (s *profilesStatements) SetDisplayName( func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string, ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) { ) (*authtypes.Profile, bool, error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) profile := &authtypes.Profile{
return Localpart: localpart,
DisplayName: displayName,
}
var changed bool
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed)
return profile, changed, err
} }
func (s *profilesStatements) SelectProfilesBySearch( func (s *profilesStatements) SelectProfilesBySearch(

View file

@ -45,7 +45,7 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON usera
const countUsersLastSeenAfterSQL = "" + const countUsersLastSeenAfterSQL = "" +
"SELECT COUNT(*) FROM (" + "SELECT COUNT(*) FROM (" +
" SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " + " SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " +
" GROUP BY localpart" + " GROUP BY localpart" +
" ) u" " ) u"
@ -62,7 +62,7 @@ R30Users counts the number of 30 day retained users, defined as:
const countR30UsersSQL = ` const countR30UsersSQL = `
SELECT platform, COUNT(*) FROM ( SELECT platform, COUNT(*) FROM (
SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts) SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts)
FROM account_accounts users FROM userapi_accounts users
INNER JOIN INNER JOIN
(SELECT (SELECT
localpart, last_seen_ts, localpart, last_seen_ts,
@ -75,7 +75,7 @@ SELECT platform, COUNT(*) FROM (
ELSE 'unknown' ELSE 'unknown'
END END
AS platform AS platform
FROM device_devices FROM userapi_devices
) uip ) uip
ON users.localpart = uip.localpart ON users.localpart = uip.localpart
AND users.account_type <> 4 AND users.account_type <> 4
@ -121,7 +121,7 @@ GROUP BY client_type
` `
const countUserByAccountTypeSQL = ` const countUserByAccountTypeSQL = `
SELECT COUNT(*) FROM account_accounts WHERE account_type = ANY($1) SELECT COUNT(*) FROM userapi_accounts WHERE account_type = ANY($1)
` `
// $1 = All non guest AccountType IDs // $1 = All non guest AccountType IDs
@ -134,7 +134,7 @@ SELECT user_type, COUNT(*) AS count FROM (
WHEN account_type = $2 AND appservice_id IS NULL THEN 'guest' WHEN account_type = $2 AND appservice_id IS NULL THEN 'guest'
WHEN account_type = ANY($1) AND appservice_id IS NOT NULL THEN 'bridged' WHEN account_type = ANY($1) AND appservice_id IS NOT NULL THEN 'bridged'
END AS user_type END AS user_type
FROM account_accounts FROM userapi_accounts
WHERE created_ts > $3 WHERE created_ts > $3
) AS t GROUP BY user_type ) AS t GROUP BY user_type
` `
@ -143,14 +143,14 @@ SELECT user_type, COUNT(*) AS count FROM (
const updateUserDailyVisitsSQL = ` const updateUserDailyVisitsSQL = `
INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent) INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent)
SELECT u.localpart, u.device_id, $1, MAX(u.user_agent) SELECT u.localpart, u.device_id, $1, MAX(u.user_agent)
FROM device_devices AS u FROM userapi_devices AS u
LEFT JOIN ( LEFT JOIN (
SELECT localpart, device_id, timestamp FROM userapi_daily_visits SELECT localpart, device_id, timestamp FROM userapi_daily_visits
WHERE timestamp = $1 WHERE timestamp = $1
) udv ) udv
ON u.localpart = udv.localpart AND u.device_id = udv.device_id ON u.localpart = udv.localpart AND u.device_id = udv.device_id
INNER JOIN device_devices d ON d.localpart = u.localpart INNER JOIN userapi_devices d ON d.localpart = u.localpart
INNER JOIN account_accounts a ON a.localpart = u.localpart INNER JOIN userapi_accounts a ON a.localpart = u.localpart
WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3 WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3
AND a.account_type in (1, 3) AND a.account_type in (1, 3)
GROUP BY u.localpart, u.device_id GROUP BY u.localpart, u.device_id

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/shared" "github.com/matrix-org/dendrite/userapi/storage/shared"
// Import the postgres database driver. // Import the postgres database driver.
@ -36,6 +37,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: rename tables",
Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables,
})
if err = m.Up(base.Context()); err != nil {
return nil, err
}
accountDataTable, err := NewPostgresAccountDataTable(db) accountDataTable, err := NewPostgresAccountDataTable(db)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)

View file

@ -26,7 +26,7 @@ import (
const threepidSchema = ` const threepidSchema = `
-- Stores data about third party identifiers -- Stores data about third party identifiers
CREATE TABLE IF NOT EXISTS account_threepid ( CREATE TABLE IF NOT EXISTS userapi_threepids (
-- The third party identifier -- The third party identifier
threepid TEXT NOT NULL, threepid TEXT NOT NULL,
-- The 3PID medium -- The 3PID medium
@ -37,20 +37,20 @@ CREATE TABLE IF NOT EXISTS account_threepid (
PRIMARY KEY(threepid, medium) PRIMARY KEY(threepid, medium)
); );
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart); CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart);
` `
const selectLocalpartForThreePIDSQL = "" + const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" + const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1" "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
const insertThreePIDSQL = "" + const insertThreePIDSQL = "" +
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
const deleteThreePIDSQL = "" + const deleteThreePIDSQL = "" +
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
type threepidStatements struct { type threepidStatements struct {
selectLocalpartForThreePIDStmt *sql.Stmt selectLocalpartForThreePIDStmt *sql.Stmt

View file

@ -96,20 +96,24 @@ func (d *Database) GetProfileByLocalpart(
// localpart. Returns an error if something went wrong with the SQL query // localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL( func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string, ctx context.Context, localpart string, avatarURL string,
) error { ) (profile *authtypes.Profile, changed bool, err error) {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
return err
}) })
return
} }
// SetDisplayName updates the display name of the profile associated with the given // SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query // localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName( func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string, ctx context.Context, localpart string, displayName string,
) error { ) (profile *authtypes.Profile, changed bool, err error) {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
return err
}) })
return
} }
// SetPassword sets the account password to the given hash. // SetPassword sets the account password to the given hash.

View file

@ -25,7 +25,7 @@ import (
const accountDataSchema = ` const accountDataSchema = `
-- Stores data about accounts data. -- Stores data about accounts data.
CREATE TABLE IF NOT EXISTS account_data ( CREATE TABLE IF NOT EXISTS userapi_account_datas (
-- The Matrix user ID localpart for this account -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room) -- The room ID for this data (empty string if not specific to a room)
@ -40,15 +40,15 @@ CREATE TABLE IF NOT EXISTS account_data (
` `
const insertAccountDataSQL = ` const insertAccountDataSQL = `
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
` `
const selectAccountDataSQL = "" + const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM account_data WHERE localpart = $1" "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
const selectAccountDataByTypeSQL = "" + const selectAccountDataByTypeSQL = "" +
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
type accountDataStatements struct { type accountDataStatements struct {
db *sql.DB db *sql.DB

View file

@ -32,7 +32,7 @@ import (
const accountsSchema = ` const accountsSchema = `
-- Stores data about accounts. -- Stores data about accounts.
CREATE TABLE IF NOT EXISTS account_accounts ( CREATE TABLE IF NOT EXISTS userapi_accounts (
-- The Matrix user ID localpart for this account -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,
-- When this account was first created, as a unix timestamp (ms resolution). -- When this account was first created, as a unix timestamp (ms resolution).
@ -51,22 +51,22 @@ CREATE TABLE IF NOT EXISTS account_accounts (
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
const updatePasswordSQL = "" + const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
const deactivateAccountSQL = "" + const deactivateAccountSQL = "" +
"UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1"
const selectAccountByLocalpartSQL = "" + const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" + const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0"
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0" "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0"
type accountsStatements struct { type accountsStatements struct {
db *sql.DB db *sql.DB

View file

@ -8,8 +8,8 @@ import (
func UpIsActive(ctx context.Context, tx *sql.Tx) error { func UpIsActive(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, ` _, err := tx.ExecContext(ctx, `
ALTER TABLE account_accounts RENAME TO account_accounts_tmp; ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE account_accounts ( CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
password_hash TEXT, password_hash TEXT,
@ -17,13 +17,13 @@ CREATE TABLE account_accounts (
is_deactivated BOOLEAN DEFAULT 0 is_deactivated BOOLEAN DEFAULT 0
); );
INSERT INSERT
INTO account_accounts ( INTO userapi_accounts (
localpart, created_ts, password_hash, appservice_id localpart, created_ts, password_hash, appservice_id
) SELECT ) SELECT
localpart, created_ts, password_hash, appservice_id localpart, created_ts, password_hash, appservice_id
FROM account_accounts_tmp FROM userapi_accounts_tmp
; ;
DROP TABLE account_accounts_tmp;`) DROP TABLE userapi_accounts_tmp;`)
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }
@ -32,21 +32,21 @@ DROP TABLE account_accounts_tmp;`)
func DownIsActive(ctx context.Context, tx *sql.Tx) error { func DownIsActive(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, ` _, err := tx.ExecContext(ctx, `
ALTER TABLE account_accounts RENAME TO account_accounts_tmp; ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE account_accounts ( CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
password_hash TEXT, password_hash TEXT,
appservice_id TEXT appservice_id TEXT
); );
INSERT INSERT
INTO account_accounts ( INTO userapi_accounts (
localpart, created_ts, password_hash, appservice_id localpart, created_ts, password_hash, appservice_id
) SELECT ) SELECT
localpart, created_ts, password_hash, appservice_id localpart, created_ts, password_hash, appservice_id
FROM account_accounts_tmp FROM userapi_accounts_tmp
; ;
DROP TABLE account_accounts_tmp;`) DROP TABLE userapi_accounts_tmp;`)
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

View file

@ -8,8 +8,8 @@ import (
func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, ` _, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices RENAME TO device_devices_tmp; ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp;
CREATE TABLE device_devices ( CREATE TABLE userapi_devices (
access_token TEXT PRIMARY KEY, access_token TEXT PRIMARY KEY,
session_id INTEGER, session_id INTEGER,
device_id TEXT , device_id TEXT ,
@ -22,12 +22,12 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
UNIQUE (localpart, device_id) UNIQUE (localpart, device_id)
); );
INSERT INSERT
INTO device_devices ( INTO userapi_devices (
access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent
) SELECT ) SELECT
access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', '' access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', ''
FROM device_devices_tmp; FROM userapi_devices_tmp;
DROP TABLE device_devices_tmp;`) DROP TABLE userapi_devices_tmp;`)
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }
@ -36,8 +36,8 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, ` _, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices RENAME TO device_devices_tmp; ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp;
CREATE TABLE IF NOT EXISTS device_devices ( CREATE TABLE IF NOT EXISTS userapi_devices (
access_token TEXT PRIMARY KEY, access_token TEXT PRIMARY KEY,
session_id INTEGER, session_id INTEGER,
device_id TEXT , device_id TEXT ,
@ -47,12 +47,12 @@ CREATE TABLE IF NOT EXISTS device_devices (
UNIQUE (localpart, device_id) UNIQUE (localpart, device_id)
); );
INSERT INSERT
INTO device_devices ( INTO userapi_devices (
access_token, session_id, device_id, localpart, created_ts, display_name access_token, session_id, device_id, localpart, created_ts, display_name
) SELECT ) SELECT
access_token, session_id, device_id, localpart, created_ts, display_name access_token, session_id, device_id, localpart, created_ts, display_name
FROM device_devices_tmp; FROM userapi_devices_tmp;
DROP TABLE device_devices_tmp;`) DROP TABLE userapi_devices_tmp;`)
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

View file

@ -9,8 +9,8 @@ import (
func UpAddAccountType(ctx context.Context, tx *sql.Tx) error { func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
// initially set every account to useraccount, change appservice and guest accounts afterwards // initially set every account to useraccount, change appservice and guest accounts afterwards
// (user = 1, guest = 2, admin = 3, appservice = 4) // (user = 1, guest = 2, admin = 3, appservice = 4)
_, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts RENAME TO account_accounts_tmp; _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE account_accounts ( CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
password_hash TEXT, password_hash TEXT,
@ -19,15 +19,15 @@ CREATE TABLE account_accounts (
account_type INTEGER NOT NULL account_type INTEGER NOT NULL
); );
INSERT INSERT
INTO account_accounts ( INTO userapi_accounts (
localpart, created_ts, password_hash, appservice_id, account_type localpart, created_ts, password_hash, appservice_id, account_type
) SELECT ) SELECT
localpart, created_ts, password_hash, appservice_id, 1 localpart, created_ts, password_hash, appservice_id, 1
FROM account_accounts_tmp FROM userapi_accounts_tmp
; ;
UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; UPDATE userapi_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE account_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*'; UPDATE userapi_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*';
DROP TABLE account_accounts_tmp;`) DROP TABLE userapi_accounts_tmp;`)
if err != nil { if err != nil {
return fmt.Errorf("failed to add column: %w", err) return fmt.Errorf("failed to add column: %w", err)
} }
@ -35,7 +35,7 @@ DROP TABLE account_accounts_tmp;`)
} }
func DownAddAccountType(ctx context.Context, tx *sql.Tx) error { func DownAddAccountType(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts DROP COLUMN account_type;`) _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts DROP COLUMN account_type;`)
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

View file

@ -0,0 +1,109 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"strings"
)
var renameTableMappings = map[string]string{
"account_accounts": "userapi_accounts",
"account_data": "userapi_account_datas",
"device_devices": "userapi_devices",
"account_e2e_room_keys": "userapi_key_backups",
"account_e2e_room_keys_versions": "userapi_key_backup_versions",
"login_tokens": "userapi_login_tokens",
"open_id_tokens": "userapi_openid_tokens",
"account_profiles": "userapi_profiles",
"account_threepid": "userapi_threepids",
}
var renameIndicesMappings = map[string]string{
"device_localpart_id_idx": "userapi_device_localpart_id_idx",
"e2e_room_keys_idx": "userapi_key_backups_idx",
"e2e_room_keys_versions_idx": "userapi_key_backups_versions_idx",
"account_e2e_room_keys_versions_idx": "userapi_key_backup_versions_idx",
"login_tokens_expiration_idx": "userapi_login_tokens_expiration_idx",
"account_threepid_localpart": "userapi_threepid_idx",
}
func UpRenameTables(ctx context.Context, tx *sql.Tx) error {
for old, new := range renameTableMappings {
// SQLite has no "IF EXISTS" so check if the table exists.
var name string
if err := tx.QueryRowContext(
ctx, "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = $1;", old,
).Scan(&name); err != nil {
if err == sql.ErrNoRows {
continue
}
return err
}
q := fmt.Sprintf(
"ALTER TABLE %s RENAME TO %s;", old, new,
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
}
}
for old, new := range renameIndicesMappings {
var query string
if err := tx.QueryRowContext(
ctx, "SELECT sql FROM sqlite_schema WHERE type = 'index' AND name = $1;", old,
).Scan(&query); err != nil {
if err == sql.ErrNoRows {
continue
}
return err
}
query = strings.Replace(query, old, new, 1)
if _, err := tx.ExecContext(ctx, fmt.Sprintf("DROP INDEX %s;", old)); err != nil {
return fmt.Errorf("drop index %q to %q error: %w", old, new, err)
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return fmt.Errorf("recreate index %q to %q error: %w", old, new, err)
}
}
return nil
}
func DownRenameTables(ctx context.Context, tx *sql.Tx) error {
for old, new := range renameTableMappings {
// SQLite has no "IF EXISTS" so check if the table exists.
var name string
if err := tx.QueryRowContext(
ctx, "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = $1;", new,
).Scan(&name); err != nil {
if err == sql.ErrNoRows {
continue
}
return err
}
q := fmt.Sprintf(
"ALTER TABLE %s RENAME TO %s;", new, old,
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
}
}
for old, new := range renameIndicesMappings {
var query string
if err := tx.QueryRowContext(
ctx, "SELECT sql FROM sqlite_schema WHERE type = 'index' AND name = $1;", new,
).Scan(&query); err != nil {
if err == sql.ErrNoRows {
continue
}
return err
}
query = strings.Replace(query, new, old, 1)
if _, err := tx.ExecContext(ctx, fmt.Sprintf("DROP INDEX %s;", new)); err != nil {
return fmt.Errorf("drop index %q to %q error: %w", new, old, err)
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return fmt.Errorf("recreate index %q to %q error: %w", new, old, err)
}
}
return nil
}

View file

@ -35,7 +35,7 @@ const devicesSchema = `
-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; -- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
-- Stores data about devices. -- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices ( CREATE TABLE IF NOT EXISTS userapi_devices (
access_token TEXT PRIMARY KEY, access_token TEXT PRIMARY KEY,
session_id INTEGER, session_id INTEGER,
device_id TEXT , device_id TEXT ,
@ -51,38 +51,38 @@ CREATE TABLE IF NOT EXISTS device_devices (
` `
const insertDeviceSQL = "" + const insertDeviceSQL = "" +
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + "INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
const selectDevicesCountSQL = "" + const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM device_devices" "SELECT COUNT(access_token) FROM userapi_devices"
const selectDeviceByTokenSQL = "" + const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" + const selectDeviceByIDSQL = "" +
"SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" + const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
const deleteDeviceSQL = "" + const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" + const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" + const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)"
const selectDevicesByIDSQL = "" + const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" + const updateDeviceLastSeen = "" +
"UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
type devicesStatements struct { type devicesStatements struct {
db *sql.DB db *sql.DB

View file

@ -26,7 +26,7 @@ import (
) )
const keyBackupTableSchema = ` const keyBackupTableSchema = `
CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( CREATE TABLE IF NOT EXISTS userapi_key_backups (
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
session_id TEXT NOT NULL, session_id TEXT NOT NULL,
@ -37,31 +37,31 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
is_verified BOOLEAN NOT NULL, is_verified BOOLEAN NOT NULL,
session_data TEXT NOT NULL session_data TEXT NOT NULL
); );
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON userapi_key_backups(user_id, room_id, session_id, version);
CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON userapi_key_backups(user_id, version);
` `
const insertBackupKeySQL = "" + const insertBackupKeySQL = "" +
"INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + "INSERT INTO userapi_key_backups(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
const updateBackupKeySQL = "" + const updateBackupKeySQL = "" +
"UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + "UPDATE userapi_key_backups SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
"WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8"
const countKeysSQL = "" + const countKeysSQL = "" +
"SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
const selectKeysSQL = "" + const selectKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2" "WHERE user_id = $1 AND version = $2"
const selectKeysByRoomIDSQL = "" + const selectKeysByRoomIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3" "WHERE user_id = $1 AND version = $2 AND room_id = $3"
const selectKeysByRoomIDAndSessionIDSQL = "" + const selectKeysByRoomIDAndSessionIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
type keyBackupStatements struct { type keyBackupStatements struct {

View file

@ -27,7 +27,7 @@ import (
const keyBackupVersionTableSchema = ` const keyBackupVersionTableSchema = `
-- the metadata for each generation of encrypted e2e session backups -- the metadata for each generation of encrypted e2e session backups
CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( CREATE TABLE IF NOT EXISTS userapi_key_backup_versions (
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
-- this means no 2 users will ever have the same version of e2e session backups which strictly -- this means no 2 users will ever have the same version of e2e session backups which strictly
-- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1.
@ -38,26 +38,26 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
deleted INTEGER DEFAULT 0 NOT NULL deleted INTEGER DEFAULT 0 NOT NULL
); );
CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backup_versions_idx ON userapi_key_backup_versions(user_id, version);
` `
const insertKeyBackupSQL = "" + const insertKeyBackupSQL = "" +
"INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" "INSERT INTO userapi_key_backup_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
const updateKeyBackupAuthDataSQL = "" + const updateKeyBackupAuthDataSQL = "" +
"UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" "UPDATE userapi_key_backup_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
const updateKeyBackupETagSQL = "" + const updateKeyBackupETagSQL = "" +
"UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" "UPDATE userapi_key_backup_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
const deleteKeyBackupSQL = "" + const deleteKeyBackupSQL = "" +
"UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" "UPDATE userapi_key_backup_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
const selectKeyBackupSQL = "" + const selectKeyBackupSQL = "" +
"SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" "SELECT algorithm, auth_data, etag, deleted FROM userapi_key_backup_versions WHERE user_id = $1 AND version = $2"
const selectLatestVersionSQL = "" + const selectLatestVersionSQL = "" +
"SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" "SELECT MAX(version) FROM userapi_key_backup_versions WHERE user_id = $1"
type keyBackupVersionStatements struct { type keyBackupVersionStatements struct {
insertKeyBackupStmt *sql.Stmt insertKeyBackupStmt *sql.Stmt

View file

@ -32,7 +32,7 @@ type loginTokenStatements struct {
} }
const loginTokenSchema = ` const loginTokenSchema = `
CREATE TABLE IF NOT EXISTS login_tokens ( CREATE TABLE IF NOT EXISTS userapi_login_tokens (
-- The random value of the token issued to a user -- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY, token TEXT NOT NULL PRIMARY KEY,
-- When the token expires -- When the token expires
@ -43,17 +43,17 @@ CREATE TABLE IF NOT EXISTS login_tokens (
); );
-- This index allows efficient garbage collection of expired tokens. -- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON userapi_login_tokens(token_expires_at);
` `
const insertLoginTokenSQL = "" + const insertLoginTokenSQL = "" +
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" "INSERT INTO userapi_login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
const deleteLoginTokenSQL = "" + const deleteLoginTokenSQL = "" +
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" "DELETE FROM userapi_login_tokens WHERE token = $1 OR token_expires_at <= $2"
const selectLoginTokenSQL = "" + const selectLoginTokenSQL = "" +
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" "SELECT user_id FROM userapi_login_tokens WHERE token = $1 AND token_expires_at > $2"
func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) { func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
s := &loginTokenStatements{} s := &loginTokenStatements{}
@ -78,7 +78,7 @@ func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx
// deleteByToken removes the named token. // deleteByToken removes the named token.
// //
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. // As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient. // The userapi_login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt) stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) res, err := stmt.ExecContext(ctx, token, time.Now().UTC())

View file

@ -13,7 +13,7 @@ import (
const openIDTokenSchema = ` const openIDTokenSchema = `
-- Stores data about accounts. -- Stores data about accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens ( CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
-- The value of the token issued to a user -- The value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY, token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account -- The Matrix user ID for this account
@ -24,10 +24,10 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
` `
const insertOpenIDTokenSQL = "" + const insertOpenIDTokenSQL = "" +
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectOpenIDTokenSQL = "" + const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct { type openIDTokenStatements struct {
db *sql.DB db *sql.DB

View file

@ -27,7 +27,7 @@ import (
const profilesSchema = ` const profilesSchema = `
-- Stores data about accounts profiles. -- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS account_profiles ( CREATE TABLE IF NOT EXISTS userapi_profiles (
-- The Matrix user ID localpart for this account -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,
-- The display name for this account -- The display name for this account
@ -38,19 +38,21 @@ CREATE TABLE IF NOT EXISTS account_profiles (
` `
const insertProfileSQL = "" + const insertProfileSQL = "" +
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
const selectProfileByLocalpartSQL = "" + const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" + const setAvatarURLSQL = "" +
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
" RETURNING display_name"
const setDisplayNameSQL = "" + const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
" RETURNING avatar_url"
const selectProfilesBySearchSQL = "" + const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct { type profilesStatements struct {
db *sql.DB db *sql.DB
@ -102,18 +104,40 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL( func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) { ) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
AvatarURL: avatarURL,
}
old, err := s.SelectProfileByLocalpart(ctx, localpart)
if err != nil {
return old, false, err
}
if old.AvatarURL == avatarURL {
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
_, err = stmt.ExecContext(ctx, avatarURL, localpart) err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
return return profile, true, err
} }
func (s *profilesStatements) SetDisplayName( func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string, ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) { ) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
DisplayName: displayName,
}
old, err := s.SelectProfileByLocalpart(ctx, localpart)
if err != nil {
return old, false, err
}
if old.DisplayName == displayName {
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
_, err = stmt.ExecContext(ctx, displayName, localpart) err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
return return profile, true, err
} }
func (s *profilesStatements) SelectProfilesBySearch( func (s *profilesStatements) SelectProfilesBySearch(

View file

@ -46,7 +46,7 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON usera
const countUsersLastSeenAfterSQL = "" + const countUsersLastSeenAfterSQL = "" +
"SELECT COUNT(*) FROM (" + "SELECT COUNT(*) FROM (" +
" SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " + " SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " +
" GROUP BY localpart" + " GROUP BY localpart" +
" ) u" " ) u"
@ -63,7 +63,7 @@ R30Users counts the number of 30 day retained users, defined as:
const countR30UsersSQL = ` const countR30UsersSQL = `
SELECT platform, COUNT(*) FROM ( SELECT platform, COUNT(*) FROM (
SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts) SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts)
FROM account_accounts users FROM userapi_accounts users
INNER JOIN INNER JOIN
(SELECT (SELECT
localpart, last_seen_ts, localpart, last_seen_ts,
@ -76,7 +76,7 @@ SELECT platform, COUNT(*) FROM (
ELSE 'unknown' ELSE 'unknown'
END END
AS platform AS platform
FROM device_devices FROM userapi_devices
) uip ) uip
ON users.localpart = uip.localpart ON users.localpart = uip.localpart
AND users.account_type <> 4 AND users.account_type <> 4
@ -126,7 +126,7 @@ GROUP BY client_type
` `
const countUserByAccountTypeSQL = ` const countUserByAccountTypeSQL = `
SELECT COUNT(*) FROM account_accounts WHERE account_type IN ($1) SELECT COUNT(*) FROM userapi_accounts WHERE account_type IN ($1)
` `
// $1 = Guest AccountType // $1 = Guest AccountType
@ -139,7 +139,7 @@ SELECT user_type, COUNT(*) AS count FROM (
WHEN account_type = $4 AND appservice_id IS NULL THEN 'guest' WHEN account_type = $4 AND appservice_id IS NULL THEN 'guest'
WHEN account_type IN ($5) AND appservice_id IS NOT NULL THEN 'bridged' WHEN account_type IN ($5) AND appservice_id IS NOT NULL THEN 'bridged'
END AS user_type END AS user_type
FROM account_accounts FROM userapi_accounts
WHERE created_ts > $8 WHERE created_ts > $8
) AS t GROUP BY user_type ) AS t GROUP BY user_type
` `
@ -148,14 +148,14 @@ SELECT user_type, COUNT(*) AS count FROM (
const updateUserDailyVisitsSQL = ` const updateUserDailyVisitsSQL = `
INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent) INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent)
SELECT u.localpart, u.device_id, $1, MAX(u.user_agent) SELECT u.localpart, u.device_id, $1, MAX(u.user_agent)
FROM device_devices AS u FROM userapi_devices AS u
LEFT JOIN ( LEFT JOIN (
SELECT localpart, device_id, timestamp FROM userapi_daily_visits SELECT localpart, device_id, timestamp FROM userapi_daily_visits
WHERE timestamp = $1 WHERE timestamp = $1
) udv ) udv
ON u.localpart = udv.localpart AND u.device_id = udv.device_id ON u.localpart = udv.localpart AND u.device_id = udv.device_id
INNER JOIN device_devices d ON d.localpart = u.localpart INNER JOIN userapi_devices d ON d.localpart = u.localpart
INNER JOIN account_accounts a ON a.localpart = u.localpart INNER JOIN userapi_accounts a ON a.localpart = u.localpart
WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3 WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3
AND a.account_type in (1, 3) AND a.account_type in (1, 3)
GROUP BY u.localpart, u.device_id GROUP BY u.localpart, u.device_id

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/shared" "github.com/matrix-org/dendrite/userapi/storage/shared"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
) )
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
@ -34,6 +35,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: rename tables",
Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables,
})
if err = m.Up(base.Context()); err != nil {
return nil, err
}
accountDataTable, err := NewSQLiteAccountDataTable(db) accountDataTable, err := NewSQLiteAccountDataTable(db)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)

View file

@ -27,7 +27,7 @@ import (
const threepidSchema = ` const threepidSchema = `
-- Stores data about third party identifiers -- Stores data about third party identifiers
CREATE TABLE IF NOT EXISTS account_threepid ( CREATE TABLE IF NOT EXISTS userapi_threepids (
-- The third party identifier -- The third party identifier
threepid TEXT NOT NULL, threepid TEXT NOT NULL,
-- The 3PID medium -- The 3PID medium
@ -38,20 +38,20 @@ CREATE TABLE IF NOT EXISTS account_threepid (
PRIMARY KEY(threepid, medium) PRIMARY KEY(threepid, medium)
); );
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart); CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart);
` `
const selectLocalpartForThreePIDSQL = "" + const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" + const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1" "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
const insertThreePIDSQL = "" + const insertThreePIDSQL = "" +
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
const deleteThreePIDSQL = "" + const deleteThreePIDSQL = "" +
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
type threepidStatements struct { type threepidStatements struct {
db *sql.DB db *sql.DB

View file

@ -16,6 +16,7 @@ import (
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
@ -29,14 +30,18 @@ var (
) )
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
connStr, close := test.PrepareDBConnectionString(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr), ConnectionString: config.DataSource(connStr),
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
if err != nil { if err != nil {
t.Fatalf("NewUserAPIDatabase returned %s", err) t.Fatalf("NewUserAPIDatabase returned %s", err)
} }
return db, close return db, func() {
close()
baseclose()
}
} }
// Tests storing and getting account data // Tests storing and getting account data
@ -377,15 +382,23 @@ func Test_Profile(t *testing.T) {
// set avatar & displayname // set avatar & displayname
wantProfile.DisplayName = "Alice" wantProfile.DisplayName = "Alice"
wantProfile.AvatarURL = "mxc://aliceAvatar" gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice")
err = db.SetDisplayName(ctx, aliceLocalpart, "Alice")
assert.NoError(t, err, "unable to set displayname")
err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
// verify profile
gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart)
assert.NoError(t, err, "unable to get profile by localpart")
assert.Equal(t, wantProfile, gotProfile) assert.Equal(t, wantProfile, gotProfile)
assert.NoError(t, err, "unable to set displayname")
assert.True(t, changed)
wantProfile.AvatarURL = "mxc://aliceAvatar"
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile)
assert.True(t, changed)
// Setting the same avatar again doesn't change anything
wantProfile.AvatarURL = "mxc://aliceAvatar"
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile)
assert.False(t, changed)
// search profiles // search profiles
searchRes, err := db.SearchProfiles(ctx, "Alice", 2) searchRes, err := db.SearchProfiles(ctx, "Alice", 2)

View file

@ -84,8 +84,8 @@ type OpenIDTable interface {
type ProfileTable interface { type ProfileTable interface {
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error) SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error) SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error)
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
} }

View file

@ -106,7 +106,7 @@ func mustUpdateDeviceLastSeen(
timestamp time.Time, timestamp time.Time,
) { ) {
t.Helper() t.Helper()
_, err := db.ExecContext(ctx, "UPDATE device_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) _, err := db.ExecContext(ctx, "UPDATE userapi_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart)
if err != nil { if err != nil {
t.Fatalf("unable to update device last seen") t.Fatalf("unable to update device last seen")
} }
@ -119,7 +119,7 @@ func mustUserUpdateRegistered(
localpart string, localpart string,
timestamp time.Time, timestamp time.Time,
) { ) {
_, err := db.ExecContext(ctx, "UPDATE account_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) _, err := db.ExecContext(ctx, "UPDATE userapi_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart)
if err != nil { if err != nil {
t.Fatalf("unable to update device last seen") t.Fatalf("unable to update device last seen")
} }

View file

@ -23,13 +23,15 @@ import (
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/internal"
@ -48,9 +50,9 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap
if opts.loginTokenLifetime == 0 { if opts.loginTokenLifetime == 0 {
opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
} }
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
connStr, close := test.PrepareDBConnectionString(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType)
accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
accountDB, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr), ConnectionString: config.DataSource(connStr),
}, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
if err != nil { if err != nil {
@ -66,7 +68,10 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap
return &internal.UserInternalAPI{ return &internal.UserInternalAPI{
DB: accountDB, DB: accountDB,
ServerName: cfg.Matrix.ServerName, ServerName: cfg.Matrix.ServerName,
}, accountDB, close }, accountDB, func() {
close()
baseclose()
}
} }
func TestQueryProfile(t *testing.T) { func TestQueryProfile(t *testing.T) {
@ -79,10 +84,10 @@ func TestQueryProfile(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
} }
if err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil { if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
t.Fatalf("failed to set avatar url: %s", err) t.Fatalf("failed to set avatar url: %s", err)
} }
if err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil { if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
t.Fatalf("failed to set display name: %s", err) t.Fatalf("failed to set display name: %s", err)
} }