Merge pull request #44 from globekeeper/release/upstream_0.10.4

Release/upstream 0.10.4
This commit is contained in:
Daniel Aloni 2022-10-27 17:57:41 +03:00 committed by GitHub
commit bc17086f63
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
80 changed files with 2619 additions and 743 deletions

View file

@ -1,5 +1,23 @@
# Changelog
## Dendrite 0.10.4 (2022-10-21)
### Features
* Various tables belonging to the user API will be renamed so that they are namespaced with the `userapi_` prefix
* Note that, after upgrading to this version, you should not revert to an older version of Dendrite as the database changes **will not** be reverted automatically
* The backoff and retry behaviour in the federation API has been refactored and improved
### Fixes
* Private read receipt support is now advertised in the client `/versions` endpoint
* Private read receipts will now clear notification counts properly
* A bug where a false `leave` membership transition was inserted into the timeline after accepting an invite has been fixed
* Some panics caused by concurrent map writes in the key server have been fixed
* The sync API now calculates membership transitions from state deltas more accurately
* Transaction IDs are now scoped to endpoints, which should fix some bugs where transaction ID reuse could cause nonsensical cached responses from some endpoints
* The length of the `type`, `sender`, `state_key` and `room_id` fields in events are now verified by number of bytes rather than codepoints after a spec clarification, reverting a change made in Dendrite 0.9.6
## Dendrite 0.10.3 (2022-10-14)
### Features

View file

@ -19,6 +19,8 @@ import (
"net/http"
"time"
"github.com/matrix-org/gomatrixserverlib"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
@ -27,7 +29,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/util"
@ -120,20 +121,6 @@ func SetAvatarURL(
}
}
res := &userapi.QueryProfileResponse{}
err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{
UserID: userID,
}, res)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed")
return jsonerror.InternalServerError()
}
oldProfile := &authtypes.Profile{
Localpart: localpart,
DisplayName: res.DisplayName,
AvatarURL: res.AvatarURL,
}
setRes := &userapi.PerformSetAvatarURLResponse{}
if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{
Localpart: localpart,
@ -142,41 +129,17 @@ func SetAvatarURL(
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed")
return jsonerror.InternalServerError()
}
var roomsRes api.QueryRoomsForUserResponse
err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
UserID: device.UserID,
WantMembership: "join",
}, &roomsRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed")
return jsonerror.InternalServerError()
}
newProfile := authtypes.Profile{
Localpart: localpart,
DisplayName: oldProfile.DisplayName,
AvatarURL: r.AvatarURL,
}
events, err := buildMembershipEvents(
req.Context(), roomsRes.RoomIDs, newProfile, userID, cfg, evTime, rsAPI,
)
switch e := err.(type) {
case nil:
case gomatrixserverlib.BadJSONError:
// No need to build new membership events, since nothing changed
if !setRes.Changed {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(e.Error()),
Code: http.StatusOK,
JSON: struct{}{},
}
default:
util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed")
return jsonerror.InternalServerError()
}
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError()
response, err := updateProfile(req.Context(), rsAPI, device, setRes.Profile, userID, cfg, evTime)
if err != nil {
return response
}
return util.JSONResponse{
@ -249,47 +212,51 @@ func SetDisplayName(
}
}
pRes := &userapi.QueryProfileResponse{}
err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{
UserID: userID,
}, pRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed")
return jsonerror.InternalServerError()
}
oldProfile := &authtypes.Profile{
Localpart: localpart,
DisplayName: pRes.DisplayName,
AvatarURL: pRes.AvatarURL,
}
profileRes := &userapi.PerformUpdateDisplayNameResponse{}
err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{
Localpart: localpart,
DisplayName: r.DisplayName,
}, &struct{}{})
}, profileRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed")
return jsonerror.InternalServerError()
}
// No need to build new membership events, since nothing changed
if !profileRes.Changed {
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
response, err := updateProfile(req.Context(), rsAPI, device, profileRes.Profile, userID, cfg, evTime)
if err != nil {
return response
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
func updateProfile(
ctx context.Context, rsAPI api.ClientRoomserverAPI, device *userapi.Device,
profile *authtypes.Profile,
userID string, cfg *config.ClientAPI, evTime time.Time,
) (util.JSONResponse, error) {
var res api.QueryRoomsForUserResponse
err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: device.UserID,
WantMembership: "join",
}, &res)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed")
return jsonerror.InternalServerError()
}
newProfile := authtypes.Profile{
Localpart: localpart,
DisplayName: r.DisplayName,
AvatarURL: oldProfile.AvatarURL,
util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed")
return jsonerror.InternalServerError(), err
}
events, err := buildMembershipEvents(
req.Context(), res.RoomIDs, newProfile, userID, cfg, evTime, rsAPI,
ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI,
)
switch e := err.(type) {
case nil:
@ -297,21 +264,17 @@ func SetDisplayName(
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(e.Error()),
}
}, e
default:
util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed")
return jsonerror.InternalServerError()
util.GetLogger(ctx).WithError(err).Error("buildMembershipEvents failed")
return jsonerror.InternalServerError(), e
}
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil {
util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError(), err
}
return util.JSONResponse{}, nil
}
// getProfile gets the full profile of a user by querying the database or a

View file

@ -19,6 +19,9 @@ import (
"net/http"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil"
@ -26,8 +29,6 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
type redactionContent struct {
@ -51,7 +52,7 @@ func SendRedaction(
if txnID != nil {
// Try to fetch response from transactionsCache
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
return *res
}
}
@ -144,7 +145,7 @@ func SendRedaction(
// Add response to transactionsCache
if txnID != nil {
txnCache.AddTransaction(device.AccessToken, *txnID, &res)
txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res)
}
return res

View file

@ -72,6 +72,7 @@ func Setup(
unstableFeatures := map[string]bool{
"org.matrix.e2e_cross_signing": true,
"org.matrix.msc2285.stable": true,
}
for _, msc := range cfg.MSCs.MSCs {
unstableFeatures["org.matrix."+msc] = true
@ -179,7 +180,7 @@ func Setup(
// server notifications
if cfg.Matrix.ServerNotices.Enabled {
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")
serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, cfg)
serverNotificationSender, err := getSenderDevice(context.Background(), rsAPI, userAPI, cfg)
if err != nil {
logrus.WithError(err).Fatal("unable to get account for sending sending server notices")
}

View file

@ -86,7 +86,7 @@ func SendEvent(
if txnID != nil {
// Try to fetch response from transactionsCache
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
return *res
}
}
@ -206,7 +206,7 @@ func SendEvent(
}
// Add response to transactionsCache
if txnID != nil {
txnCache.AddTransaction(device.AccessToken, *txnID, &res)
txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res)
}
// Take a note of how long it took to generate the event vs submit

View file

@ -16,12 +16,13 @@ import (
"encoding/json"
"net/http"
"github.com/matrix-org/util"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/internal/transactions"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
// SendToDevice handles PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}
@ -33,7 +34,7 @@ func SendToDevice(
eventType string, txnID *string,
) util.JSONResponse {
if txnID != nil {
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
return *res
}
}
@ -63,7 +64,7 @@ func SendToDevice(
}
if txnID != nil {
txnCache.AddTransaction(device.AccessToken, *txnID, &res)
txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res)
}
return res

View file

@ -21,7 +21,6 @@ import (
"net/http"
"time"
"github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/tokens"
@ -29,6 +28,8 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/roomserver/version"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
@ -73,7 +74,7 @@ func SendServerNotice(
if txnID != nil {
// Try to fetch response from transactionsCache
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
return *res
}
}
@ -251,7 +252,7 @@ func SendServerNotice(
}
// Add response to transactionsCache
if txnID != nil {
txnCache.AddTransaction(device.AccessToken, *txnID, &res)
txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res)
}
// Take a note of how long it took to generate the event vs submit
@ -276,6 +277,7 @@ func (r sendServerNoticeRequest) valid() (ok bool) {
// It returns an userapi.Device, which is used for building the event
func getSenderDevice(
ctx context.Context,
rsAPI api.ClientRoomserverAPI,
userAPI userapi.ClientUserAPI,
cfg *config.ClientAPI,
) (*userapi.Device, error) {
@ -290,16 +292,32 @@ func getSenderDevice(
return nil, err
}
// set the avatarurl for the user
res := &userapi.PerformSetAvatarURLResponse{}
// Set the avatarurl for the user
avatarRes := &userapi.PerformSetAvatarURLResponse{}
if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart,
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
}, res); err != nil {
}, avatarRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed")
return nil, err
}
profile := avatarRes.Profile
// Set the displayname for the user
displayNameRes := &userapi.PerformUpdateDisplayNameResponse{}
if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart,
DisplayName: cfg.Matrix.ServerNotices.DisplayName,
}, displayNameRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed")
return nil, err
}
if displayNameRes.Changed {
profile.DisplayName = cfg.Matrix.ServerNotices.DisplayName
}
// Check if we got existing devices
deviceRes := &userapi.QueryDevicesResponse{}
err = userAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{
@ -309,7 +327,15 @@ func getSenderDevice(
return nil, err
}
// We've got an existing account, return the first device of it
if len(deviceRes.Devices) > 0 {
// If there were changes to the profile, create a new membership event
if displayNameRes.Changed || avatarRes.Changed {
_, err = updateProfile(ctx, rsAPI, &deviceRes.Devices[0], profile, accRes.Account.UserID, cfg, time.Now())
if err != nil {
return nil, err
}
}
return &deviceRes.Devices[0], nil
}

View file

@ -179,7 +179,10 @@ func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, a
body, _ = io.ReadAll(regResp.Body)
return "", fmt.Errorf(gjson.GetBytes(body, "error").Str)
}
r, _ := io.ReadAll(regResp.Body)
r, err := io.ReadAll(regResp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response body (HTTP %d): %w", regResp.StatusCode, err)
}
return gjson.GetBytes(r, "access_token").Str, nil
}

View file

@ -231,9 +231,9 @@ GEM
jekyll-seo-tag (~> 2.1)
minitest (5.15.0)
multipart-post (2.1.1)
nokogiri (1.13.6-arm64-darwin)
nokogiri (1.13.9-arm64-darwin)
racc (~> 1.4)
nokogiri (1.13.6-x86_64-linux)
nokogiri (1.13.9-x86_64-linux)
racc (~> 1.4)
octokit (4.22.0)
faraday (>= 0.9)

View file

@ -116,17 +116,14 @@ func NewInternalAPI(
_ = federationDB.RemoveAllServersFromBlacklist()
}
stats := &statistics.Statistics{
DB: federationDB,
FailuresUntilBlacklist: cfg.FederationMaxRetries,
}
stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1)
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
queues := queue.NewOutgoingQueues(
federationDB, base.ProcessContext,
cfg.Matrix.DisableFederation,
cfg.Matrix.ServerName, federation, rsAPI, stats,
cfg.Matrix.ServerName, federation, rsAPI, &stats,
&queue.SigningInfo{
KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey,
@ -183,5 +180,5 @@ func NewInternalAPI(
}
time.AfterFunc(time.Minute, cleanExpiredEDUs)
return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues, keyRing)
return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, &stats, caches, queues, keyRing)
}

View file

@ -21,21 +21,22 @@ import (
"sync"
"time"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"go.uber.org/atomic"
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"go.uber.org/atomic"
)
const (
maxPDUsPerTransaction = 50
maxEDUsPerTransaction = 50
maxEDUsPerTransaction = 100
maxPDUsInMemory = 128
maxEDUsInMemory = 128
queueIdleTimeout = time.Second * 30
@ -64,7 +65,6 @@ type destinationQueue struct {
pendingPDUs []*queuedPDU // PDUs waiting to be sent
pendingEDUs []*queuedEDU // EDUs waiting to be sent
pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs
interruptBackoff chan bool // interrupts backoff
}
// Send event adds the event to the pending queue for the destination.
@ -75,21 +75,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination)
return
}
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
if err := oq.db.AssociatePDUWithDestination(
oq.process.Context(),
"", // TODO: remove this, as we don't need to persist the transaction ID
oq.destination, // the destination server name
receipt, // NIDs from federationapi_queue_json table
); err != nil {
logrus.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination)
return
}
// Check if the destination is blacklisted. If it isn't then wake
// up the queue.
if !oq.statistics.Blacklisted() {
// If there's room in memory to hold the event then add it to the
// list.
oq.pendingMutex.Lock()
@ -102,12 +88,9 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
oq.overflowed.Store(true)
}
oq.pendingMutex.Unlock()
// Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded()
select {
case oq.notify <- struct{}{}:
default:
}
if !oq.backingOff.Load() {
oq.wakeQueueAndNotify()
}
}
@ -119,22 +102,7 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination)
return
}
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
if err := oq.db.AssociateEDUWithDestination(
oq.process.Context(),
oq.destination, // the destination server name
receipt, // NIDs from federationapi_queue_json table
event.Type,
nil, // this will use the default expireEDUTypes map
); err != nil {
logrus.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination)
return
}
// Check if the destination is blacklisted. If it isn't then wake
// up the queue.
if !oq.statistics.Blacklisted() {
// If there's room in memory to hold the event then add it to the
// list.
oq.pendingMutex.Lock()
@ -147,24 +115,47 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
oq.overflowed.Store(true)
}
oq.pendingMutex.Unlock()
if !oq.backingOff.Load() {
oq.wakeQueueAndNotify()
}
}
// handleBackoffNotifier is registered as the backoff notification
// callback with Statistics. It will wakeup and notify the queue
// if the queue is currently backing off.
func (oq *destinationQueue) handleBackoffNotifier() {
// Only wake up the queue if it is backing off.
// Otherwise there is no pending work for the queue to handle
// so waking the queue would be a waste of resources.
if oq.backingOff.Load() {
oq.wakeQueueAndNotify()
}
}
// wakeQueueAndNotify ensures the destination queue is running and notifies it
// that there is pending work.
func (oq *destinationQueue) wakeQueueAndNotify() {
// Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded()
// Notify the queue that there are events ready to send.
select {
case oq.notify <- struct{}{}:
default:
}
}
}
// wakeQueueIfNeeded will wake up the destination queue if it is
// not already running. If it is running but it is backing off
// then we will interrupt the backoff, causing any federation
// requests to retry.
func (oq *destinationQueue) wakeQueueIfNeeded() {
// If we are backing off then interrupt the backoff.
// Clear the backingOff flag and update the backoff metrics if it was set.
if oq.backingOff.CompareAndSwap(true, false) {
oq.interruptBackoff <- true
destinationQueueBackingOff.Dec()
}
// If we aren't running then wake up the queue.
if !oq.running.Load() {
// Start the queue.
@ -196,38 +187,54 @@ func (oq *destinationQueue) getPendingFromDatabase() {
gotEDUs[edu.receipt.String()] = struct{}{}
}
overflowed := false
if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 {
// We have room in memory for some PDUs - let's request no more than that.
if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil {
if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, maxPDUsInMemory); err == nil {
if len(pdus) == maxPDUsInMemory {
overflowed = true
}
for receipt, pdu := range pdus {
if _, ok := gotPDUs[receipt.String()]; ok {
continue
}
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu})
retrieved = true
if len(oq.pendingPDUs) == maxPDUsInMemory {
break
}
}
} else {
logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination)
}
}
if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 {
// We have room in memory for some EDUs - let's request no more than that.
if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil {
if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, maxEDUsInMemory); err == nil {
if len(edus) == maxEDUsInMemory {
overflowed = true
}
for receipt, edu := range edus {
if _, ok := gotEDUs[receipt.String()]; ok {
continue
}
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu})
retrieved = true
if len(oq.pendingEDUs) == maxEDUsInMemory {
break
}
}
} else {
logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination)
}
}
// If we've retrieved all of the events from the database with room to spare
// in memory then we'll no longer consider this queue to be overflowed.
if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory {
if !overflowed {
oq.overflowed.Store(false)
} else {
}
// If we've retrieved some events then notify the destination queue goroutine.
if retrieved {
@ -238,6 +245,24 @@ func (oq *destinationQueue) getPendingFromDatabase() {
}
}
// checkNotificationsOnClose checks for any remaining notifications
// and starts a new backgroundSend goroutine if any exist.
func (oq *destinationQueue) checkNotificationsOnClose() {
// NOTE : If we are stopping the queue due to blacklist then it
// doesn't matter if we have been notified of new work since
// this queue instance will be deleted anyway.
if !oq.statistics.Blacklisted() {
select {
case <-oq.notify:
// We received a new notification in between the
// idle timeout firing and stopping the goroutine.
// Immediately restart the queue.
oq.wakeQueueAndNotify()
default:
}
}
}
// backgroundSend is the worker goroutine for sending events.
func (oq *destinationQueue) backgroundSend() {
// Check if a worker is already running, and if it isn't, then
@ -245,10 +270,17 @@ func (oq *destinationQueue) backgroundSend() {
if !oq.running.CompareAndSwap(false, true) {
return
}
// Register queue cleanup functions.
// NOTE : The ordering here is very intentional.
defer oq.checkNotificationsOnClose()
defer oq.running.Store(false)
destinationQueueRunning.Inc()
defer destinationQueueRunning.Dec()
defer oq.queues.clearQueue(oq)
defer oq.running.Store(false)
idleTimeout := time.NewTimer(queueIdleTimeout)
defer idleTimeout.Stop()
// Mark the queue as overflowed, so we will consult the database
// to see if there's anything new to send.
@ -261,59 +293,33 @@ func (oq *destinationQueue) backgroundSend() {
oq.getPendingFromDatabase()
}
// Reset the queue idle timeout.
if !idleTimeout.Stop() {
select {
case <-idleTimeout.C:
default:
}
}
idleTimeout.Reset(queueIdleTimeout)
// If we have nothing to do then wait either for incoming events, or
// until we hit an idle timeout.
select {
case <-oq.notify:
// There's work to do, either because getPendingFromDatabase
// told us there is, or because a new event has come in via
// sendEvent/sendEDU.
case <-time.After(queueIdleTimeout):
// told us there is, a new event has come in via sendEvent/sendEDU,
// or we are backing off and it is time to retry.
case <-idleTimeout.C:
// The worker is idle so stop the goroutine. It'll get
// restarted automatically the next time we have an event to
// send.
return
case <-oq.process.Context().Done():
// The parent process is shutting down, so stop.
oq.statistics.ClearBackoff()
return
}
// If we are backing off this server then wait for the
// backoff duration to complete first, or until explicitly
// told to retry.
until, blacklisted := oq.statistics.BackoffInfo()
if blacklisted {
// It's been suggested that we should give up because the backoff
// has exceeded a maximum allowable value. Clean up the in-memory
// buffers at this point. The PDU clean-up is already on a defer.
logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = nil
oq.pendingEDUs = nil
oq.pendingMutex.Unlock()
return
}
if until != nil && until.After(time.Now()) {
// We haven't backed off yet, so wait for the suggested amount of
// time.
duration := time.Until(*until)
logrus.Debugf("Backing off %q for %s", oq.destination, duration)
oq.backingOff.Store(true)
destinationQueueBackingOff.Inc()
select {
case <-time.After(duration):
case <-oq.interruptBackoff:
}
destinationQueueBackingOff.Dec()
oq.backingOff.Store(false)
}
// Work out which PDUs/EDUs to include in the next transaction.
oq.pendingMutex.RLock()
pduCount := len(oq.pendingPDUs)
@ -328,99 +334,52 @@ func (oq *destinationQueue) backgroundSend() {
toSendEDUs := oq.pendingEDUs[:eduCount]
oq.pendingMutex.RUnlock()
// If we didn't get anything from the database and there are no
// pending EDUs then there's nothing to do - stop here.
if pduCount == 0 && eduCount == 0 {
continue
}
// If we have pending PDUs or EDUs then construct a transaction.
// Try sending the next transaction and see what happens.
transaction, pc, ec, terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
if terr != nil {
// We failed to send the transaction. Mark it as a failure.
oq.statistics.Failure()
} else if transaction {
// If we successfully sent the transaction then clear out
// the pending events and EDUs, and wipe our transaction ID.
oq.statistics.Success()
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs[:pc] {
oq.pendingPDUs[i] = nil
_, blacklisted := oq.statistics.Failure()
if !blacklisted {
// Register the backoff state and exit the goroutine.
// It'll get restarted automatically when the backoff
// completes.
oq.backingOff.Store(true)
destinationQueueBackingOff.Inc()
return
} else {
// Immediately trigger the blacklist logic.
oq.blacklistDestination()
return
}
for i := range oq.pendingEDUs[:ec] {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = oq.pendingPDUs[pc:]
oq.pendingEDUs = oq.pendingEDUs[ec:]
oq.pendingMutex.Unlock()
} else {
oq.handleTransactionSuccess(pduCount, eduCount)
}
}
}
// nextTransaction creates a new transaction from the pending event
// queue and sends it. Returns true if a transaction was sent or
// false otherwise.
// queue and sends it.
// Returns an error if the transaction wasn't sent.
func (oq *destinationQueue) nextTransaction(
pdus []*queuedPDU,
edus []*queuedEDU,
) (bool, int, int, error) {
// If there's no projected transaction ID then generate one. If
// the transaction succeeds then we'll set it back to "" so that
// we generate a new one next time. If it fails, we'll preserve
// it so that we retry with the same transaction ID.
oq.transactionIDMutex.Lock()
if oq.transactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
oq.transactionIDMutex.Unlock()
) error {
// Create the transaction.
t := gomatrixserverlib.Transaction{
PDUs: []json.RawMessage{},
EDUs: []gomatrixserverlib.EDU{},
}
t.Origin = oq.origin
t.Destination = oq.destination
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = oq.transactionID
// If we didn't get anything from the database and there are no
// pending EDUs then there's nothing to do - stop here.
if len(pdus) == 0 && len(edus) == 0 {
return false, 0, 0, nil
}
var pduReceipts []*shared.Receipt
var eduReceipts []*shared.Receipt
// Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction.
for _, pdu := range pdus {
if pdu == nil || pdu.pdu == nil {
continue
}
// Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, pdu.pdu.JSON())
pduReceipts = append(pduReceipts, pdu.receipt)
}
// Do the same for pending EDUS in the queue.
for _, edu := range edus {
if edu == nil || edu.edu == nil {
continue
}
t.EDUs = append(t.EDUs, *edu.edu)
eduReceipts = append(eduReceipts, edu.receipt)
}
t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus)
logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
// Try to send the transaction to the destination server.
// TODO: we should check for 500-ish fails vs 400-ish here,
// since we shouldn't queue things indefinitely in response
// to a 400-ish error
ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5)
defer cancel()
_, err := oq.client.SendTransaction(ctx, t)
switch err.(type) {
switch errResponse := err.(type) {
case nil:
// Clean up the transaction in the database.
if pduReceipts != nil {
@ -439,16 +398,129 @@ func (oq *destinationQueue) nextTransaction(
oq.transactionIDMutex.Lock()
oq.transactionID = ""
oq.transactionIDMutex.Unlock()
return true, len(t.PDUs), len(t.EDUs), nil
return nil
case gomatrix.HTTPError:
// Report that we failed to send the transaction and we
// will retry again, subject to backoff.
return false, 0, 0, err
// TODO: we should check for 500-ish fails vs 400-ish here,
// since we shouldn't queue things indefinitely in response
// to a 400-ish error
code := errResponse.Code
logrus.Debug("Transaction failed with HTTP", code)
return err
default:
logrus.WithFields(logrus.Fields{
"destination": oq.destination,
logrus.ErrorKey: err,
}).Debugf("Failed to send transaction %q", t.TransactionID)
return false, 0, 0, err
return err
}
}
// createTransaction generates a gomatrixserverlib.Transaction from the provided pdus and edus.
// It also returns the associated event receipts so they can be cleaned from the database in
// the case of a successful transaction.
func (oq *destinationQueue) createTransaction(
pdus []*queuedPDU,
edus []*queuedEDU,
) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) {
// If there's no projected transaction ID then generate one. If
// the transaction succeeds then we'll set it back to "" so that
// we generate a new one next time. If it fails, we'll preserve
// it so that we retry with the same transaction ID.
oq.transactionIDMutex.Lock()
if oq.transactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
oq.transactionIDMutex.Unlock()
t := gomatrixserverlib.Transaction{
PDUs: []json.RawMessage{},
EDUs: []gomatrixserverlib.EDU{},
}
t.Origin = oq.origin
t.Destination = oq.destination
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = oq.transactionID
var pduReceipts []*shared.Receipt
var eduReceipts []*shared.Receipt
// Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction.
for _, pdu := range pdus {
// These should never be nil.
if pdu == nil || pdu.pdu == nil {
continue
}
// Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, pdu.pdu.JSON())
pduReceipts = append(pduReceipts, pdu.receipt)
}
// Do the same for pending EDUS in the queue.
for _, edu := range edus {
// These should never be nil.
if edu == nil || edu.edu == nil {
continue
}
t.EDUs = append(t.EDUs, *edu.edu)
eduReceipts = append(eduReceipts, edu.receipt)
}
return t, pduReceipts, eduReceipts
}
// blacklistDestination removes all pending PDUs and EDUs that have been cached
// and deletes this queue.
func (oq *destinationQueue) blacklistDestination() {
// It's been suggested that we should give up because the backoff
// has exceeded a maximum allowable value. Clean up the in-memory
// buffers at this point. The PDU clean-up is already on a defer.
logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = nil
oq.pendingEDUs = nil
oq.pendingMutex.Unlock()
// Delete this queue as no more messages will be sent to this
// destination until it is no longer blacklisted.
oq.statistics.AssignBackoffNotifier(nil)
oq.queues.clearQueue(oq)
}
// handleTransactionSuccess updates the cached event queues as well as the success and
// backoff information for this server.
func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) {
// If we successfully sent the transaction then clear out
// the pending events and EDUs, and wipe our transaction ID.
oq.statistics.Success()
oq.pendingMutex.Lock()
defer oq.pendingMutex.Unlock()
for i := range oq.pendingPDUs[:pduCount] {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs[:eduCount] {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = oq.pendingPDUs[pduCount:]
oq.pendingEDUs = oq.pendingEDUs[eduCount:]
if len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0 {
select {
case oq.notify <- struct{}{}:
default:
}
}
}

View file

@ -24,6 +24,7 @@ import (
"github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@ -171,14 +172,16 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d
client: oqs.client,
statistics: oqs.statistics.ForServer(destination),
notify: make(chan struct{}, 1),
interruptBackoff: make(chan bool),
signing: oqs.signing,
}
oq.statistics.AssignBackoffNotifier(oq.handleBackoffNotifier)
oqs.queues[destination] = oq
}
return oq
}
// clearQueue removes the queue for the provided destination from the
// set of destination queues.
func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) {
oqs.queuesMutex.Lock()
defer oqs.queuesMutex.Unlock()
@ -245,11 +248,25 @@ func (oqs *OutgoingQueues) SendEvent(
}
for destination := range destmap {
if queue := oqs.getQueue(destination); queue != nil {
if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() {
queue.sendEvent(ev, nid)
} else {
delete(destmap, destination)
}
}
// Create a database entry that associates the given PDU NID with
// this destinations queue. We'll then be able to retrieve the PDU
// later.
if err := oqs.db.AssociatePDUWithDestinations(
oqs.process.Context(),
destmap,
nid, // NIDs from federationapi_queue_json table
); err != nil {
logrus.WithError(err).Errorf("failed to associate PDUs %q with destinations", nid)
return err
}
return nil
}
@ -319,11 +336,27 @@ func (oqs *OutgoingQueues) SendEDU(
}
for destination := range destmap {
if queue := oqs.getQueue(destination); queue != nil {
if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() {
queue.sendEDU(e, nid)
} else {
delete(destmap, destination)
}
}
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
if err := oqs.db.AssociateEDUWithDestinations(
oqs.process.Context(),
destmap, // the destination server name
nid, // NIDs from federationapi_queue_json table
e.Type,
nil, // this will use the default expireEDUTypes map
); err != nil {
logrus.WithError(err).Errorf("failed to associate EDU with destinations")
return err
}
return nil
}
@ -332,7 +365,9 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
if oqs.disabled {
return
}
oqs.statistics.ForServer(srv).RemoveBlacklist()
if queue := oqs.getQueue(srv); queue != nil {
queue.statistics.ClearBackoff()
queue.wakeQueueIfNeeded()
}
}

File diff suppressed because it is too large Load diff

View file

@ -2,6 +2,7 @@ package statistics
import (
"math"
"math/rand"
"sync"
"time"
@ -20,12 +21,23 @@ type Statistics struct {
servers map[gomatrixserverlib.ServerName]*ServerStatistics
mutex sync.RWMutex
backoffTimers map[gomatrixserverlib.ServerName]*time.Timer
backoffMutex sync.RWMutex
// How many times should we tolerate consecutive failures before we
// just blacklist the host altogether? The backoff is exponential,
// so the max time here to attempt is 2**failures seconds.
FailuresUntilBlacklist uint32
}
func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics {
return Statistics{
DB: db,
FailuresUntilBlacklist: failuresUntilBlacklist,
backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer),
}
}
// ForServer returns server statistics for the given server name. If it
// does not exist, it will create empty statistics and return those.
func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics {
@ -45,7 +57,6 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
server = &ServerStatistics{
statistics: s,
serverName: serverName,
interrupt: make(chan struct{}),
}
s.servers[serverName] = server
s.mutex.Unlock()
@ -70,23 +81,37 @@ type ServerStatistics struct {
backoffStarted atomic.Bool // is the backoff started
backoffUntil atomic.Value // time.Time until this backoff interval ends
backoffCount atomic.Uint32 // number of times BackoffDuration has been called
interrupt chan struct{} // interrupts the backoff goroutine
successCounter atomic.Uint32 // how many times have we succeeded?
backoffNotifier func() // notifies destination queue when backoff completes
notifierMutex sync.Mutex
}
const maxJitterMultiplier = 1.4
const minJitterMultiplier = 0.8
// duration returns how long the next backoff interval should be.
func (s *ServerStatistics) duration(count uint32) time.Duration {
return time.Second * time.Duration(math.Exp2(float64(count)))
// Add some jitter to minimise the chance of having multiple backoffs
// ending at the same time.
jitter := rand.Float64()*(maxJitterMultiplier-minJitterMultiplier) + minJitterMultiplier
duration := time.Millisecond * time.Duration(math.Exp2(float64(count))*jitter*1000)
return duration
}
// cancel will interrupt the currently active backoff.
func (s *ServerStatistics) cancel() {
s.blacklisted.Store(false)
s.backoffUntil.Store(time.Time{})
select {
case s.interrupt <- struct{}{}:
default:
s.ClearBackoff()
}
// AssignBackoffNotifier configures the channel to send to when
// a backoff completes.
func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) {
s.notifierMutex.Lock()
defer s.notifierMutex.Unlock()
s.backoffNotifier = notifier
}
// Success updates the server statistics with a new successful
@ -95,8 +120,8 @@ func (s *ServerStatistics) cancel() {
// we will unblacklist it.
func (s *ServerStatistics) Success() {
s.cancel()
s.successCounter.Inc()
s.backoffCount.Store(0)
s.successCounter.Inc()
if s.statistics.DB != nil {
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
@ -105,13 +130,17 @@ func (s *ServerStatistics) Success() {
}
// Failure marks a failure and starts backing off if needed.
// The next call to BackoffIfRequired will do the right thing
// after this. It will return the time that the current failure
// It will return the time that the current failure
// will result in backoff waiting until, and a bool signalling
// whether we have blacklisted and therefore to give up.
func (s *ServerStatistics) Failure() (time.Time, bool) {
// Return immediately if we have blacklisted this node.
if s.blacklisted.Load() {
return time.Time{}, true
}
// If we aren't already backing off, this call will start
// a new backoff period. Increase the failure counter and
// a new backoff period, increase the failure counter and
// start a goroutine which will wait out the backoff and
// unset the backoffStarted flag when done.
if s.backoffStarted.CompareAndSwap(false, true) {
@ -122,40 +151,48 @@ func (s *ServerStatistics) Failure() (time.Time, bool) {
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
}
}
s.ClearBackoff()
return time.Time{}, true
}
go func() {
until, ok := s.backoffUntil.Load().(time.Time)
if ok && !until.IsZero() {
select {
case <-time.After(time.Until(until)):
case <-s.interrupt:
}
s.backoffStarted.Store(false)
}
}()
}
// Check if we have blacklisted this node.
if s.blacklisted.Load() {
return time.Now(), true
}
// If we're already backing off and we haven't yet surpassed
// the deadline then return that. Repeated calls to Failure
// within a single backoff interval will have no side effects.
if until, ok := s.backoffUntil.Load().(time.Time); ok && !time.Now().After(until) {
return until, false
}
// We're either backing off and have passed the deadline, or
// we aren't backing off, so work out what the next interval
// We're starting a new back off so work out what the next interval
// will be.
count := s.backoffCount.Load()
until := time.Now().Add(s.duration(count))
s.backoffUntil.Store(until)
return until, false
s.statistics.backoffMutex.Lock()
defer s.statistics.backoffMutex.Unlock()
s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished)
}
return s.backoffUntil.Load().(time.Time), false
}
// ClearBackoff stops the backoff timer for this destination if it is running
// and removes the timer from the backoffTimers map.
func (s *ServerStatistics) ClearBackoff() {
// If the timer is still running then stop it so it's memory is cleaned up sooner.
s.statistics.backoffMutex.Lock()
defer s.statistics.backoffMutex.Unlock()
if timer, ok := s.statistics.backoffTimers[s.serverName]; ok {
timer.Stop()
}
delete(s.statistics.backoffTimers, s.serverName)
s.backoffStarted.Store(false)
}
// backoffFinished will clear the previous backoff and notify the destination queue.
func (s *ServerStatistics) backoffFinished() {
s.ClearBackoff()
// Notify the destinationQueue if one is currently running.
s.notifierMutex.Lock()
defer s.notifierMutex.Unlock()
if s.backoffNotifier != nil {
s.backoffNotifier()
}
}
// BackoffInfo returns information about the current or previous backoff.
@ -174,6 +211,12 @@ func (s *ServerStatistics) Blacklisted() bool {
return s.blacklisted.Load()
}
// RemoveBlacklist removes the blacklisted status from the server.
func (s *ServerStatistics) RemoveBlacklist() {
s.cancel()
s.backoffCount.Store(0)
}
// SuccessCount returns the number of successful requests. This is
// usually useful in constructing transaction IDs.
func (s *ServerStatistics) SuccessCount() uint32 {

View file

@ -7,9 +7,7 @@ import (
)
func TestBackoff(t *testing.T) {
stats := Statistics{
FailuresUntilBlacklist: 7,
}
stats := NewStatistics(nil, 7)
server := ServerStatistics{
statistics: &stats,
serverName: "test.com",
@ -36,7 +34,7 @@ func TestBackoff(t *testing.T) {
// Get the duration.
_, blacklist := server.BackoffInfo()
duration := time.Until(until).Round(time.Second)
duration := time.Until(until)
// Unset the backoff, or otherwise our next call will think that
// there's a backoff in progress and return the same result.
@ -57,8 +55,17 @@ func TestBackoff(t *testing.T) {
// Check if the duration is what we expect.
t.Logf("Backoff %d is for %s", i, duration)
if wanted := time.Second * time.Duration(math.Exp2(float64(i))); !blacklist && duration != wanted {
t.Fatalf("Backoff %d should have been %s but was %s", i, wanted, duration)
roundingAllowance := 0.01
minDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*minJitterMultiplier*1000-roundingAllowance)
maxDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*maxJitterMultiplier*1000+roundingAllowance)
var inJitterRange bool
if duration >= minDuration && duration <= maxDuration {
inJitterRange = true
} else {
inJitterRange = false
}
if !blacklist && !inJitterRange {
t.Fatalf("Backoff %d should have been between %s and %s but was %s", i, minDuration, maxDuration, duration)
}
}
}

View file

@ -18,9 +18,10 @@ import (
"context"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
@ -38,8 +39,8 @@ type Database interface {
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error)
AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error
AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error

View file

@ -52,6 +52,10 @@ type Receipt struct {
nid int64
}
func NewReceipt(nid int64) Receipt {
return Receipt{nid: nid}
}
func (r *Receipt) String() string {
return fmt.Sprintf("%d", r.nid)
}

View file

@ -38,9 +38,9 @@ var defaultExpireEDUTypes = map[string]time.Duration{
// AssociateEDUWithDestination creates an association that the
// destination queues will use to determine which JSON blobs to send
// to which servers.
func (d *Database) AssociateEDUWithDestination(
func (d *Database) AssociateEDUWithDestinations(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
destinations map[gomatrixserverlib.ServerName]struct{},
receipt *Receipt,
eduType string,
expireEDUTypes map[string]time.Duration,
@ -59,17 +59,18 @@ func (d *Database) AssociateEDUWithDestination(
expiresAt = 0
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationQueueEDUs.InsertQueueEDU(
var err error
for destination := range destinations {
err = d.FederationQueueEDUs.InsertQueueEDU(
ctx, // context
txn, // SQL transaction
eduType, // EDU type for coalescing
serverName, // destination server name
destination, // destination server name
receipt.nid, // NID from the federationapi_queue_json table
expiresAt, // The timestamp this EDU will expire
); err != nil {
return fmt.Errorf("InsertQueueEDU: %w", err)
)
}
return nil
return err
})
}

View file

@ -27,23 +27,23 @@ import (
// AssociatePDUWithDestination creates an association that the
// destination queues will use to determine which JSON blobs to send
// to which servers.
func (d *Database) AssociatePDUWithDestination(
func (d *Database) AssociatePDUWithDestinations(
ctx context.Context,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
destinations map[gomatrixserverlib.ServerName]struct{},
receipt *Receipt,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationQueuePDUs.InsertQueuePDU(
var err error
for destination := range destinations {
err = d.FederationQueuePDUs.InsertQueuePDU(
ctx, // context
txn, // SQL transaction
transactionID, // transaction ID
serverName, // destination server name
"", // transaction ID
destination, // destination server name
receipt.nid, // NID from the federationapi_queue_json table
); err != nil {
return fmt.Errorf("InsertQueuePDU: %w", err)
)
}
return nil
return err
})
}

View file

@ -35,6 +35,7 @@ func TestExpireEDUs(t *testing.T) {
}
ctx := context.Background()
destinations := map[gomatrixserverlib.ServerName]struct{}{"localhost": {}}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateFederationDatabase(t, dbType)
defer close()
@ -43,7 +44,7 @@ func TestExpireEDUs(t *testing.T) {
receipt, err := db.StoreJSON(ctx, "{}")
assert.NoError(t, err)
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MReceipt, expireEDUTypes)
err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MReceipt, expireEDUTypes)
assert.NoError(t, err)
}
// add data without expiry
@ -51,7 +52,7 @@ func TestExpireEDUs(t *testing.T) {
assert.NoError(t, err)
// m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, "m.read_marker", expireEDUTypes)
err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, "m.read_marker", expireEDUTypes)
assert.NoError(t, err)
// Delete expired EDUs
@ -67,7 +68,7 @@ func TestExpireEDUs(t *testing.T) {
receipt, err = db.StoreJSON(ctx, "{}")
assert.NoError(t, err)
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes)
err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes)
assert.NoError(t, err)
err = db.DeleteExpiredEDUs(ctx)

4
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241
github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a
github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.15
@ -50,6 +50,7 @@ require (
golang.org/x/term v0.0.0-20220919170432-7a66f970e087
gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0
gotest.tools/v3 v3.4.0
nhooyr.io/websocket v1.8.7
)
@ -129,7 +130,6 @@ require (
gopkg.in/macaroon.v2 v2.1.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gotest.tools/v3 v3.4.0 // indirect
)
go 1.18

4
go.sum
View file

@ -385,8 +385,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241 h1:e5o68MWeU7wjTvvNKmVo655oCYesoNRoPeBb1Xfz54g=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a h1:6rJFN5NBuzZ7h5meYkLtXKa6VFZfDc8oVXHd4SDXr5o=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 h1:lzkSQvBv8TuqKJCPoVwOVvEnARTlua5rrNy/Qw2Vxeo=
github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=

View file

@ -13,6 +13,8 @@
package transactions
import (
"net/url"
"path/filepath"
"sync"
"time"
@ -29,6 +31,7 @@ type txnsMap map[CacheKey]*util.JSONResponse
type CacheKey struct {
AccessToken string
TxnID string
Endpoint string
}
// Cache represents a temporary store for response entries.
@ -57,14 +60,14 @@ func NewWithCleanupPeriod(cleanupPeriod time.Duration) *Cache {
return &t
}
// FetchTransaction looks up an entry for the (accessToken, txnID) tuple in Cache.
// FetchTransaction looks up an entry for the (accessToken, txnID, req.URL) tuple in Cache.
// Looks in both the txnMaps.
// Returns (JSON response, true) if txnID is found, else the returned bool is false.
func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse, bool) {
func (t *Cache) FetchTransaction(accessToken, txnID string, u *url.URL) (*util.JSONResponse, bool) {
t.RLock()
defer t.RUnlock()
for _, txns := range t.txnsMaps {
res, ok := txns[CacheKey{accessToken, txnID}]
res, ok := txns[CacheKey{accessToken, txnID, filepath.Dir(u.Path)}]
if ok {
return res, true
}
@ -72,13 +75,12 @@ func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse,
return nil, false
}
// AddTransaction adds an entry for the (accessToken, txnID) tuple in Cache.
// AddTransaction adds an entry for the (accessToken, txnID, req.URL) tuple in Cache.
// Adds to the front txnMap.
func (t *Cache) AddTransaction(accessToken, txnID string, res *util.JSONResponse) {
func (t *Cache) AddTransaction(accessToken, txnID string, u *url.URL, res *util.JSONResponse) {
t.Lock()
defer t.Unlock()
t.txnsMaps[0][CacheKey{accessToken, txnID}] = res
t.txnsMaps[0][CacheKey{accessToken, txnID, filepath.Dir(u.Path)}] = res
}
// cacheCleanService is responsible for cleaning up entries after cleanupPeriod.

View file

@ -14,6 +14,9 @@ package transactions
import (
"net/http"
"net/url"
"path/filepath"
"reflect"
"strconv"
"testing"
@ -24,6 +27,16 @@ type fakeType struct {
ID string `json:"ID"`
}
func TestCompare(t *testing.T) {
u1, _ := url.Parse("/send/1?accessToken=123")
u2, _ := url.Parse("/send/1")
c1 := CacheKey{"1", "2", filepath.Dir(u1.Path)}
c2 := CacheKey{"1", "2", filepath.Dir(u2.Path)}
if !reflect.DeepEqual(c1, c2) {
t.Fatalf("Cache keys differ: %+v <> %+v", c1, c2)
}
}
var (
fakeAccessToken = "aRandomAccessToken"
fakeAccessToken2 = "anotherRandomAccessToken"
@ -34,23 +47,28 @@ var (
fakeResponse2 = &util.JSONResponse{
Code: http.StatusOK, JSON: fakeType{ID: "1"},
}
fakeResponse3 = &util.JSONResponse{
Code: http.StatusOK, JSON: fakeType{ID: "2"},
}
)
// TestCache creates a New Cache and tests AddTransaction & FetchTransaction
func TestCache(t *testing.T) {
fakeTxnCache := New()
fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse)
u, _ := url.Parse("")
fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, u, fakeResponse)
// Add entries for noise.
for i := 1; i <= 100; i++ {
fakeTxnCache.AddTransaction(
fakeAccessToken,
fakeTxnID+strconv.Itoa(i),
u,
&util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: strconv.Itoa(i)}},
)
}
testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID)
testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID, u)
if !ok {
t.Error("Failed to retrieve entry for txnID: ", fakeTxnID)
} else if testResponse.JSON != fakeResponse.JSON {
@ -59,20 +77,30 @@ func TestCache(t *testing.T) {
}
// TestCacheScope ensures transactions with the same transaction ID are not shared
// across multiple access tokens.
// across multiple access tokens and endpoints.
func TestCacheScope(t *testing.T) {
cache := New()
cache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse)
cache.AddTransaction(fakeAccessToken2, fakeTxnID, fakeResponse2)
sendEndpoint, _ := url.Parse("/send/1?accessToken=test")
sendToDeviceEndpoint, _ := url.Parse("/sendToDevice/1")
cache.AddTransaction(fakeAccessToken, fakeTxnID, sendEndpoint, fakeResponse)
cache.AddTransaction(fakeAccessToken2, fakeTxnID, sendEndpoint, fakeResponse2)
cache.AddTransaction(fakeAccessToken2, fakeTxnID, sendToDeviceEndpoint, fakeResponse3)
if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID); !ok {
if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID, sendEndpoint); !ok {
t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID)
} else if res.JSON != fakeResponse.JSON {
t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse.JSON, res.JSON)
}
if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID); !ok {
if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID, sendEndpoint); !ok {
t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID)
} else if res.JSON != fakeResponse2.JSON {
t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON)
}
// Ensure the txnID is not shared across endpoints
if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID, sendToDeviceEndpoint); !ok {
t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID)
} else if res.JSON != fakeResponse3.JSON {
t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON)
}
}

View file

@ -17,7 +17,7 @@ var build string
const (
VersionMajor = 0
VersionMinor = 10
VersionPatch = 3
VersionPatch = 4
VersionTag = "" // example: "rc1"
)

View file

@ -250,6 +250,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap
// nolint:gocyclo
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
var respMu sync.Mutex
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
@ -329,7 +330,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys)
if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
// perform key queries for remote devices
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
@ -407,7 +408,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
func (a *KeyInternalAPI) remoteKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
) map[string]map[string][]string {
fetchRemote := make(map[string]map[string][]string)
for domain, userToDeviceMap := range domainToDeviceKeys {
@ -415,7 +416,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
// we can't safely return keys from the db when all devices are requested as we don't
// know if one has just been added.
if len(deviceIDs) > 0 {
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs)
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs)
if err == nil {
continue
}
@ -471,7 +472,9 @@ func (a *KeyInternalAPI) queryRemoteKeys(
close(resultCh)
}()
for result := range resultCh {
processResult := func(result *gomatrixserverlib.RespQueryKeys) {
respMu.Lock()
defer respMu.Unlock()
for userID, nest := range result.DeviceKeys {
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
for deviceID, deviceKey := range nest {
@ -494,6 +497,10 @@ func (a *KeyInternalAPI) queryRemoteKeys(
// TODO: do we want to persist these somewhere now
// that we have fetched them?
}
for result := range resultCh {
processResult(result)
}
}
func (a *KeyInternalAPI) queryRemoteKeysOnServer(
@ -541,9 +548,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
}
// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
respMu.Lock()
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil)
respMu.Unlock()
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
if err != nil {
logrus.WithFields(logrus.Fields{
logrus.ErrorKey: err,
@ -567,25 +572,26 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
res.Failures[serverName] = map[string]interface{}{
"message": err.Error(),
}
respMu.Unlock()
// last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server
// is down, better to return something than nothing at all. Clients can know about the failure by
// inspecting the failures map though so they can know it's a cached response.
for userID, dkeys := range devKeys {
// drop the error as it's already a failure at this point
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys)
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys)
}
// Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
respMu.Lock()
if len(res.DeviceKeys) > 0 {
delete(res.Failures, serverName)
}
respMu.Unlock()
}
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
) error {
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
// if we can't query the db or there are fewer keys than requested, fetch from remote.
@ -598,9 +604,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
if len(deviceIDs) == 0 && len(keys) == 0 {
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
}
respMu.Lock()
if res.DeviceKeys[userID] == nil {
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
}
respMu.Unlock()
for _, key := range keys {
if len(key.KeyJSON) == 0 {
@ -610,7 +618,9 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"`
}{key.DisplayName})
respMu.Lock()
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
respMu.Unlock()
}
return nil
}

View file

@ -428,6 +428,13 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
return
}
// Only notify clients about retired invite events, if the user didn't accept the invite.
// The PDU stream will also receive an event about accepting the invitation, so there should
// be a "smooth" transition from invite -> join, and not invite -> leave -> join
if msg.Membership == gomatrixserverlib.Join {
return
}
// Notify any active sync requests that the invite has been retired.
s.inviteStream.Advance(pduPos)
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID)

View file

@ -28,8 +28,9 @@ import (
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const outputRoomEventsSchema = `
@ -133,7 +134,7 @@ const updateEventJSONSQL = "" +
"UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2"
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
const selectStateInRangeSQL = "" +
const selectStateInRangeFilteredSQL = "" +
"SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" +
" FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
@ -146,6 +147,15 @@ const selectStateInRangeSQL = "" +
" ORDER BY id ASC" +
" LIMIT $9"
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
const selectStateInRangeSQL = "" +
"SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" +
" FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
" AND room_id = ANY($3)" +
" ORDER BY id ASC" +
" LIMIT $4"
const deleteEventsForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1"
@ -178,6 +188,7 @@ type outputRoomEventsStatements struct {
selectRecentEventsStmt *sql.Stmt
selectRecentEventsForSyncStmt *sql.Stmt
selectEarlyEventsStmt *sql.Stmt
selectStateInRangeFilteredStmt *sql.Stmt
selectStateInRangeStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt
@ -214,6 +225,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
{&s.selectRecentEventsStmt, selectRecentEventsSQL},
{&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL},
{&s.selectEarlyEventsStmt, selectEarlyEventsSQL},
{&s.selectStateInRangeFilteredStmt, selectStateInRangeFilteredSQL},
{&s.selectStateInRangeStmt, selectStateInRangeSQL},
{&s.updateEventJSONStmt, updateEventJSONSQL},
{&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL},
@ -240,9 +252,12 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
ctx context.Context, txn *sql.Tx, r types.Range,
stateFilter *gomatrixserverlib.StateFilter, roomIDs []string,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)
var rows *sql.Rows
var err error
if stateFilter != nil {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeFilteredStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err := stmt.QueryContext(
rows, err = stmt.QueryContext(
ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
pq.StringArray(senders),
pq.StringArray(notSenders),
@ -251,6 +266,14 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
stateFilter.ContainsURL,
stateFilter.Limit,
)
} else {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)
rows, err = stmt.QueryContext(
ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
r.High()-r.Low(),
)
}
if err != nil {
return nil, nil, err
}

View file

@ -5,10 +5,11 @@ import (
"database/sql"
"fmt"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
type DatabaseTransaction struct {
@ -277,6 +278,7 @@ func (d *DatabaseTransaction) GetBackwardTopologyPos(
// exclusive of oldPos, inclusive of newPos, for the rooms in which
// the user has new membership events.
// A list of joined room IDs is also returned in case the caller needs it.
// nolint:gocyclo
func (d *DatabaseTransaction) GetStateDeltas(
ctx context.Context, device *userapi.Device,
r types.Range, userID string,
@ -311,7 +313,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
}
// get all the state events ever (i.e. for all available rooms) between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, nil, allRoomIDs)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
@ -326,6 +328,22 @@ func (d *DatabaseTransaction) GetStateDeltas(
return nil, nil, err
}
// get all the state events ever (i.e. for all available rooms) between these two positions
stateNeededFiltered, eventMapFiltered, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
stateFiltered, err := d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
// find out which rooms this user is peeking, if any.
// We do this before joins so any peeks get overwritten
peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r)
@ -371,6 +389,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
// If our membership is now join but the previous membership wasn't
// then this is a "join transition", so we'll insert this room.
if prevMembership != membership {
newlyJoinedRooms[roomID] = true
// Get the full room state, as we'll send that down for a newly
// joined room instead of a delta.
var s []types.StreamEvent
@ -383,8 +402,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
// Add the information for this room into the state so that
// it will get added with all of the rest of the joined rooms.
state[roomID] = s
newlyJoinedRooms[roomID] = true
stateFiltered[roomID] = s
}
// We won't add joined rooms into the delta at this point as they
@ -395,7 +413,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
deltas = append(deltas, types.StateDelta{
Membership: membership,
MembershipPos: ev.StreamPosition,
StateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
StateEvents: d.StreamEventsToEvents(device, stateFiltered[roomID]),
RoomID: roomID,
})
break
@ -407,7 +425,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, types.StateDelta{
Membership: gomatrixserverlib.Join,
StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
StateEvents: d.StreamEventsToEvents(device, stateFiltered[joinedRoomID]),
RoomID: joinedRoomID,
NewlyJoined: newlyJoinedRooms[joinedRoomID],
})

View file

@ -29,8 +29,9 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const outputRoomEventsSchema = `
@ -189,21 +190,36 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
for _, roomID := range roomIDs {
inputParams = append(inputParams, roomID)
}
stmt, params, err := prepareWithFilters(
var (
stmt *sql.Stmt
params []any
err error
)
if stateFilter != nil {
stmt, params, err = prepareWithFilters(
s.db, txn, stmtSQL, inputParams,
stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes,
nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc,
)
} else {
stmt, params, err = prepareWithFilters(
s.db, txn, stmtSQL, inputParams,
nil, nil,
nil, nil,
nil, nil, int(r.High()-r.Low()), FilterOrderAsc,
)
}
if err != nil {
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
defer internal.CloseAndLogIfError(ctx, stmt, "selectStateInRange: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, nil, err
}
defer rows.Close() // nolint: errcheck
defer internal.CloseAndLogIfError(ctx, rows, "selectStateInRange: rows.close() failed")
// Fetch all the state change events for all rooms between the two positions then loop each event and:
// - Keep a cache of the event by ID (99% of state change events are for the event itself)
// - For each room ID, build up an array of event IDs which represents cumulative adds/removes
@ -269,6 +285,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID(
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt)
defer internal.CloseAndLogIfError(ctx, stmt, "SelectMaxEventID: stmt.close() failed")
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
@ -323,6 +340,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
return 0, err
}
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
defer internal.CloseAndLogIfError(ctx, insertStmt, "InsertEvent: stmt.close() failed")
_, err = insertStmt.ExecContext(
ctx,
streamPos,
@ -367,6 +385,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
if err != nil {
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
}
defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
@ -415,6 +434,8 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectEarlyEvents: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
@ -456,6 +477,8 @@ func (s *outputRoomEventsStatements) SelectEvents(
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectEvents: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
@ -558,6 +581,10 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
filter.Types, filter.NotTypes,
nil, filter.ContainsURL, filter.Limit, FilterOrderDesc,
)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectContextBeforeEvent: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
@ -596,6 +623,10 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent(
filter.Types, filter.NotTypes,
nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectContextAfterEvent: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {

View file

@ -0,0 +1,198 @@
package tables_test
import (
"context"
"database/sql"
"reflect"
"sort"
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
)
func newMembershipsTable(t *testing.T, dbType test.DBType) (tables.Memberships, *sql.DB, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.Memberships
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresMembershipsTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSqliteMembershipsTable(db)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, db, close
}
func TestMembershipsTable(t *testing.T) {
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
// Create users
var userEvents []*gomatrixserverlib.HeaderedEvent
users := []string{alice.ID}
for _, x := range room.CurrentState() {
if x.StateKeyEquals(alice.ID) {
if _, err := x.Membership(); err == nil {
userEvents = append(userEvents, x)
break
}
}
}
if len(userEvents) == 0 {
t.Fatalf("didn't find creator membership event")
}
for i := 0; i < 10; i++ {
u := test.NewUser(t)
users = append(users, u.ID)
ev := room.CreateAndInsert(t, u, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(u.ID))
userEvents = append(userEvents, ev)
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
table, _, close := newMembershipsTable(t, dbType)
defer close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
for _, ev := range userEvents {
if err := table.UpsertMembership(ctx, nil, ev, types.StreamPosition(ev.Depth()), 1); err != nil {
t.Fatalf("failed to upsert membership: %s", err)
}
}
testUpsert(t, ctx, table, userEvents[0], alice, room)
testMembershipCount(t, ctx, table, room)
testHeroes(t, ctx, table, alice, room, users)
})
}
func testHeroes(t *testing.T, ctx context.Context, table tables.Memberships, user *test.User, room *test.Room, users []string) {
// Re-slice and sort the expected users
users = users[1:]
sort.Strings(users)
type testCase struct {
name string
memberships []string
wantHeroes []string
}
testCases := []testCase{
{name: "no memberships queried", memberships: []string{}},
{name: "joined memberships queried should be limited", memberships: []string{gomatrixserverlib.Join}, wantHeroes: users[:5]},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := table.SelectHeroes(ctx, nil, room.ID, user.ID, tc.memberships)
if err != nil {
t.Fatalf("unable to select heroes: %s", err)
}
if gotLen := len(got); gotLen != len(tc.wantHeroes) {
t.Fatalf("expected %d heroes, got %d", len(tc.wantHeroes), gotLen)
}
if !reflect.DeepEqual(got, tc.wantHeroes) {
t.Fatalf("expected heroes to be %+v, got %+v", tc.wantHeroes, got)
}
})
}
}
func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) {
t.Run("membership counts are correct", func(t *testing.T) {
// After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users)
count, err := table.SelectMembershipCount(ctx, nil, room.ID, gomatrixserverlib.Join, 10)
if err != nil {
t.Fatalf("failed to get membership count: %s", err)
}
expectedCount := 6
if expectedCount != count {
t.Fatalf("expected member count to be %d, got %d", expectedCount, count)
}
// After 100 events, we should have all 11 users
count, err = table.SelectMembershipCount(ctx, nil, room.ID, gomatrixserverlib.Join, 100)
if err != nil {
t.Fatalf("failed to get membership count: %s", err)
}
expectedCount = 11
if expectedCount != count {
t.Fatalf("expected member count to be %d, got %d", expectedCount, count)
}
})
}
func testUpsert(t *testing.T, ctx context.Context, table tables.Memberships, membershipEvent *gomatrixserverlib.HeaderedEvent, user *test.User, room *test.Room) {
t.Run("upserting works as expected", func(t *testing.T) {
if err := table.UpsertMembership(ctx, nil, membershipEvent, 1, 1); err != nil {
t.Fatalf("failed to upsert membership: %s", err)
}
membership, pos, err := table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 1)
if err != nil {
t.Fatalf("failed to select membership: %s", err)
}
expectedPos := 1
if pos != expectedPos {
t.Fatalf("expected pos to be %d, got %d", expectedPos, pos)
}
if membership != gomatrixserverlib.Join {
t.Fatalf("expected membership to be join, got %s", membership)
}
// Create a new event which gets upserted and should not cause issues
ev := room.CreateAndInsert(t, user, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": gomatrixserverlib.Join,
}, test.WithStateKey(user.ID))
// Insert the same event again, but with different positions, which should get updated
if err = table.UpsertMembership(ctx, nil, ev, 2, 2); err != nil {
t.Fatalf("failed to upsert membership: %s", err)
}
// Verify the position got updated
membership, pos, err = table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 10)
if err != nil {
t.Fatalf("failed to select membership: %s", err)
}
expectedPos = 2
if pos != expectedPos {
t.Fatalf("expected pos to be %d, got %d", expectedPos, pos)
}
if membership != gomatrixserverlib.Join {
t.Fatalf("expected membership to be join, got %s", membership)
}
// If we can't find a membership, it should default to leave
if membership, _, err = table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 1); err != nil {
t.Fatalf("failed to select membership: %s", err)
}
if membership != gomatrixserverlib.Leave {
t.Fatalf("expected membership to be leave, got %s", membership)
}
})
}

View file

@ -74,7 +74,12 @@ func (p *InviteStreamProvider) IncrementalSync(
return to
}
for roomID := range retiredInvites {
if _, ok := req.Response.Rooms.Join[roomID]; !ok {
if _, ok := req.Response.Rooms.Invite[roomID]; ok {
continue
}
if _, ok := req.Response.Rooms.Join[roomID]; ok {
continue
}
lr := types.NewLeaveResponse()
h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...))
lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{
@ -88,7 +93,7 @@ func (p *InviteStreamProvider) IncrementalSync(
Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`),
})
req.Response.Rooms.Leave[roomID] = lr
}
}
return maxID

View file

@ -194,7 +194,7 @@ func (p *PDUStreamProvider) IncrementalSync(
}
}
var pos types.StreamPosition
if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil {
if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req); err != nil {
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
if err == context.DeadlineExceeded || err == context.Canceled || err == sql.ErrTxDone {
return newPos
@ -225,7 +225,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
delta types.StateDelta,
eventFilter *gomatrixserverlib.RoomEventFilter,
stateFilter *gomatrixserverlib.StateFilter,
res *types.Response,
req *types.SyncRequest,
) (types.StreamPosition, error) {
if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave {
// make sure we don't leak recent events after the leave event.
@ -290,8 +290,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
hasMembershipChange := false
for _, recentEvent := range recentStreamEvents {
if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil {
if membership, _ := recentEvent.Membership(); membership == gomatrixserverlib.Join {
req.MembershipChanges[*recentEvent.StateKey()] = struct{}{}
}
hasMembershipChange = true
break
}
}
@ -318,9 +320,9 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Join[delta.RoomID] = jr
req.Response.Rooms.Join[delta.RoomID] = jr
case gomatrixserverlib.Peek:
jr := types.NewJoinResponse()
@ -329,7 +331,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Peek[delta.RoomID] = jr
req.Response.Rooms.Peek[delta.RoomID] = jr
case gomatrixserverlib.Leave:
fallthrough // transitions to leave are the same as ban
@ -342,7 +344,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// didn't "remove" events, return that the response is limited.
lr.Timeline.Limited = limited && len(events) == len(recentEvents)
lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Leave[delta.RoomID] = lr
req.Response.Rooms.Leave[delta.RoomID] = lr
}
return latestPosition, nil

View file

@ -101,6 +101,7 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
Timeout: timeout, //
Rooms: make(map[string]string), // Populated by the PDU stream
WantFullState: wantFullState, //
MembershipChanges: make(map[string]struct{}), // Populated by the PDU stream
}, nil
}

View file

@ -4,9 +4,10 @@ import (
"context"
"time"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
type SyncRequest struct {
@ -22,6 +23,8 @@ type SyncRequest struct {
// Updated by the PDU stream.
Rooms map[string]string
// Updated by the PDU stream.
MembershipChanges map[string]struct{}
// Updated by the PDU stream.
IgnoredUsers IgnoredUsers
}

View file

@ -22,10 +22,6 @@ Forgotten room messages cannot be paginated
Local device key changes get to remote servers with correct prev_id
# Flakey
Local device key changes appear in /keys/changes
# we don't support groups
Remove group category
@ -39,12 +35,6 @@ Events in rooms with AS-hosted room aliases are sent to AS server
Inviting an AS-hosted user asks the AS server
Accesing an AS-hosted room alias asks the AS server
# Flakey, need additional investigation
Messages that notify from another user increment notification_count
Messages that highlight from another user increment unread highlight count
Notifications can be viewed with GET /notifications
# More flakey
Guest users can join guest_access rooms

View file

@ -746,3 +746,9 @@ Inbound federation can return missing events for joined visibility
outliers whose auth_events are in a different room are correctly rejected
Messages that notify from another user increment notification_count
Messages that highlight from another user increment unread highlight count
Newly joined room has correct timeline in incremental sync
When user joins a room the state is included in the next sync
When user joins a room the state is included in a gapped sync
Messages that notify from another user increment notification_count
Messages that highlight from another user increment unread highlight count
Notifications can be viewed with GET /notifications

View file

@ -96,7 +96,7 @@ type ClientUserAPI interface {
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error
SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error
SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
@ -579,7 +579,10 @@ type Notification struct {
type PerformSetAvatarURLRequest struct {
Localpart, AvatarURL string
}
type PerformSetAvatarURLResponse struct{}
type PerformSetAvatarURLResponse struct {
Profile *authtypes.Profile `json:"profile"`
Changed bool `json:"changed"`
}
type QueryNumericLocalpartResponse struct {
ID int64
@ -606,6 +609,11 @@ type PerformUpdateDisplayNameRequest struct {
Localpart, DisplayName string
}
type PerformUpdateDisplayNameResponse struct {
Profile *authtypes.Profile `json:"profile"`
Changed bool `json:"changed"`
}
type QueryLocalpartForThreePIDRequest struct {
ThreePID, Medium string
}

View file

@ -168,7 +168,7 @@ func (t *UserInternalAPITrace) QueryAccountAvailability(ctx context.Context, req
return err
}
func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error {
func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error {
err := t.Impl.SetDisplayName(ctx, req, res)
util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res))
return err

View file

@ -81,7 +81,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
readPos := msg.Header.Get(jetstream.EventID)
evType := msg.Header.Get("type")
if readPos == "" || evType != "m.read" {
if readPos == "" || (evType != "m.read" && evType != "m.read.private") {
return true
}

View file

@ -10,19 +10,24 @@ import (
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi/storage"
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{
db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "", 4, 0, 0, "")
if err != nil {
t.Fatalf("failed to create new user db: %v", err)
}
return db, close
return db, func() {
close()
baseclose()
}
}
func mustCreateEvent(t *testing.T, content string) *gomatrixserverlib.HeaderedEvent {

View file

@ -170,7 +170,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil
}
if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err
}
@ -813,7 +813,10 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
}
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
return a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
res.Profile = profile
res.Changed = changed
return err
}
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
@ -847,8 +850,11 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
}
}
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, _ *struct{}) error {
return a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
res.Profile = profile
res.Changed = changed
return err
}
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {

View file

@ -388,7 +388,7 @@ func (h *httpUserInternalAPI) QueryAccountByPassword(
func (h *httpUserInternalAPI) SetDisplayName(
ctx context.Context,
request *api.PerformUpdateDisplayNameRequest,
response *struct{},
response *api.PerformUpdateDisplayNameResponse,
) error {
return httputil.CallInternalRPCAPI(
"SetDisplayName", h.apiURL+PerformSetDisplayNamePath,

View file

@ -29,8 +29,8 @@ import (
type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error)
}
type Account interface {

View file

@ -26,7 +26,7 @@ import (
const accountDataSchema = `
-- Stores data about accounts data.
CREATE TABLE IF NOT EXISTS account_data (
CREATE TABLE IF NOT EXISTS userapi_account_datas (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room)
@ -41,15 +41,15 @@ CREATE TABLE IF NOT EXISTS account_data (
`
const insertAccountDataSQL = `
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
`
const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM account_data WHERE localpart = $1"
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
const selectAccountDataByTypeSQL = "" +
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
type accountDataStatements struct {
insertAccountDataStmt *sql.Stmt

View file

@ -32,7 +32,7 @@ import (
const accountsSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS account_accounts (
CREATE TABLE IF NOT EXISTS userapi_accounts (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- When this account was first created, as a unix timestamp (ms resolution).
@ -51,22 +51,22 @@ CREATE TABLE IF NOT EXISTS account_accounts (
`
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
const deactivateAccountSQL = "" +
"UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1"
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM account_accounts WHERE localpart ~ '^[0-9]{1,}$'"
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'"
type accountsStatements struct {
insertAccountStmt *sql.Stmt

View file

@ -7,7 +7,7 @@ import (
)
func UpIsActive(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;")
_, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
@ -15,7 +15,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
}
func DownIsActive(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN is_deactivated;")
_, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts DROP COLUMN is_deactivated;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}

View file

@ -8,9 +8,9 @@ import (
func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS ip TEXT;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000;
ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS ip TEXT;
ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
@ -19,9 +19,9 @@ ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices DROP COLUMN last_seen_ts;
ALTER TABLE device_devices DROP COLUMN ip;
ALTER TABLE device_devices DROP COLUMN user_agent;`)
ALTER TABLE userapi_devices DROP COLUMN last_seen_ts;
ALTER TABLE userapi_devices DROP COLUMN ip;
ALTER TABLE userapi_devices DROP COLUMN user_agent;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}

View file

@ -9,10 +9,10 @@ import (
func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
// initially set every account to useraccount, change appservice and guest accounts afterwards
// (user = 1, guest = 2, admin = 3, appservice = 4)
_, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1;
UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$';
ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`,
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1;
UPDATE userapi_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE userapi_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$';
ALTER TABLE userapi_accounts ALTER COLUMN account_type DROP DEFAULT;`,
)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
@ -21,7 +21,7 @@ ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`,
}
func DownAddAccountType(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN account_type;")
_, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts DROP COLUMN account_type;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}

View file

@ -8,7 +8,7 @@ import (
func UpNoGuests(ctx context.Context, tx *sql.Tx) error {
// AddAccountType introduced a bug where each user that had was registered as a regular user, but without user_id, became a guest.
_, err := tx.ExecContext(ctx, "UPDATE account_accounts SET account_type = 1 WHERE account_type = 2;")
_, err := tx.ExecContext(ctx, "UPDATE userapi_accounts SET account_type = 1 WHERE account_type = 2;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}

View file

@ -0,0 +1,102 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
)
var renameTableMappings = map[string]string{
"account_accounts": "userapi_accounts",
"account_data": "userapi_account_datas",
"device_devices": "userapi_devices",
"account_e2e_room_keys": "userapi_key_backups",
"account_e2e_room_keys_versions": "userapi_key_backup_versions",
"login_tokens": "userapi_login_tokens",
"open_id_tokens": "userapi_openid_tokens",
"account_profiles": "userapi_profiles",
"account_threepid": "userapi_threepids",
}
var renameSequenceMappings = map[string]string{
"device_session_id_seq": "userapi_device_session_id_seq",
"account_e2e_room_keys_versions_seq": "userapi_key_backup_versions_seq",
}
var renameIndicesMappings = map[string]string{
"device_localpart_id_idx": "userapi_device_localpart_id_idx",
"e2e_room_keys_idx": "userapi_key_backups_idx",
"e2e_room_keys_versions_idx": "userapi_key_backups_versions_idx",
"account_e2e_room_keys_versions_idx": "userapi_key_backup_versions_idx",
"login_tokens_expiration_idx": "userapi_login_tokens_expiration_idx",
"account_threepid_localpart": "userapi_threepid_idx",
}
// I know what you're thinking: you're wondering "why doesn't this use $1
// and pass variadic parameters to ExecContext?" — the answer is because
// PostgreSQL doesn't expect the table name to be specified as a substituted
// argument in that way so it results in a syntax error in the query.
func UpRenameTables(ctx context.Context, tx *sql.Tx) error {
for old, new := range renameTableMappings {
q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(old), pq.QuoteIdentifier(new),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
}
}
for old, new := range renameSequenceMappings {
q := fmt.Sprintf(
"ALTER SEQUENCE IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(old), pq.QuoteIdentifier(new),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
}
}
for old, new := range renameIndicesMappings {
q := fmt.Sprintf(
"ALTER INDEX IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(old), pq.QuoteIdentifier(new),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
}
}
return nil
}
func DownRenameTables(ctx context.Context, tx *sql.Tx) error {
for old, new := range renameTableMappings {
q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(new), pq.QuoteIdentifier(old),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
}
}
for old, new := range renameSequenceMappings {
q := fmt.Sprintf(
"ALTER SEQUENCE IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(new), pq.QuoteIdentifier(old),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
}
}
for old, new := range renameIndicesMappings {
q := fmt.Sprintf(
"ALTER INDEX IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(new), pq.QuoteIdentifier(old),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
}
}
return nil
}

View file

@ -31,10 +31,10 @@ import (
const devicesSchema = `
-- This sequence is used for automatic allocation of session_id.
CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
CREATE SEQUENCE IF NOT EXISTS userapi_device_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices (
CREATE TABLE IF NOT EXISTS userapi_devices (
-- The access token granted to this device. This has to be the primary key
-- so we can distinguish which device is making a given request.
access_token TEXT NOT NULL PRIMARY KEY,
@ -42,7 +42,7 @@ CREATE TABLE IF NOT EXISTS device_devices (
-- This can be used as a secure substitution of the access token in situations
-- where data is associated with access tokens (e.g. transaction storage),
-- so we don't have to store users' access tokens everywhere.
session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'),
session_id BIGINT NOT NULL DEFAULT nextval('userapi_device_session_id_seq'),
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
-- access_tokens will be clobbered based on the device ID for a user.
device_id TEXT NOT NULL,
@ -65,39 +65,39 @@ CREATE TABLE IF NOT EXISTS device_devices (
);
-- Device IDs must be unique for a given user.
CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(localpart, device_id);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id);
`
const insertDeviceSQL = "" +
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
"INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
" RETURNING session_id"
const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2"
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" +
"UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
type devicesStatements struct {
insertDeviceStmt *sql.Stmt

View file

@ -26,7 +26,7 @@ import (
)
const keyBackupTableSchema = `
CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
CREATE TABLE IF NOT EXISTS userapi_key_backups (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
session_id TEXT NOT NULL,
@ -37,31 +37,31 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
is_verified BOOLEAN NOT NULL,
session_data TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backups_idx ON userapi_key_backups(user_id, room_id, session_id, version);
CREATE INDEX IF NOT EXISTS userapi_key_backups_versions_idx ON userapi_key_backups(user_id, version);
`
const insertBackupKeySQL = "" +
"INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
"INSERT INTO userapi_key_backups(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
const updateBackupKeySQL = "" +
"UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
"UPDATE userapi_key_backups SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
"WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8"
const countKeysSQL = "" +
"SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2"
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
const selectKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2"
const selectKeysByRoomIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3"
const selectKeysByRoomIDAndSessionIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
type keyBackupStatements struct {

View file

@ -26,40 +26,40 @@ import (
)
const keyBackupVersionTableSchema = `
CREATE SEQUENCE IF NOT EXISTS account_e2e_room_keys_versions_seq;
CREATE SEQUENCE IF NOT EXISTS userapi_key_backup_versions_seq;
-- the metadata for each generation of encrypted e2e session backups
CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
CREATE TABLE IF NOT EXISTS userapi_key_backup_versions (
user_id TEXT NOT NULL,
-- this means no 2 users will ever have the same version of e2e session backups which strictly
-- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1.
version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'),
version BIGINT DEFAULT nextval('userapi_key_backup_versions_seq'),
algorithm TEXT NOT NULL,
auth_data TEXT NOT NULL,
etag TEXT NOT NULL,
deleted SMALLINT DEFAULT 0 NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backup_versions_idx ON userapi_key_backup_versions(user_id, version);
`
const insertKeyBackupSQL = "" +
"INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
"INSERT INTO userapi_key_backup_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
const updateKeyBackupAuthDataSQL = "" +
"UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
"UPDATE userapi_key_backup_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
const updateKeyBackupETagSQL = "" +
"UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
"UPDATE userapi_key_backup_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
const deleteKeyBackupSQL = "" +
"UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
"UPDATE userapi_key_backup_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
const selectKeyBackupSQL = "" +
"SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2"
"SELECT algorithm, auth_data, etag, deleted FROM userapi_key_backup_versions WHERE user_id = $1 AND version = $2"
const selectLatestVersionSQL = "" +
"SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1"
"SELECT MAX(version) FROM userapi_key_backup_versions WHERE user_id = $1"
type keyBackupVersionStatements struct {
insertKeyBackupStmt *sql.Stmt

View file

@ -26,7 +26,7 @@ import (
)
const loginTokenSchema = `
CREATE TABLE IF NOT EXISTS login_tokens (
CREATE TABLE IF NOT EXISTS userapi_login_tokens (
-- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
-- When the token expires
@ -37,17 +37,17 @@ CREATE TABLE IF NOT EXISTS login_tokens (
);
-- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
CREATE INDEX IF NOT EXISTS userapi_login_tokens_expiration_idx ON userapi_login_tokens(token_expires_at);
`
const insertLoginTokenSQL = "" +
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
"INSERT INTO userapi_login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
const deleteLoginTokenSQL = "" +
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
"DELETE FROM userapi_login_tokens WHERE token = $1 OR token_expires_at <= $2"
const selectLoginTokenSQL = "" +
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
"SELECT user_id FROM userapi_login_tokens WHERE token = $1 AND token_expires_at > $2"
type loginTokenStatements struct {
insertStmt *sql.Stmt
@ -78,7 +78,7 @@ func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx
// deleteByToken removes the named token.
//
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient.
// The userapi_login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())

View file

@ -13,7 +13,7 @@ import (
const openIDTokenSchema = `
-- Stores data about openid tokens issued for accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens (
CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
-- The value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
@ -24,10 +24,10 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
`
const insertOpenIDTokenSQL = "" +
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct {
insertTokenStmt *sql.Stmt

View file

@ -27,7 +27,7 @@ import (
const profilesSchema = `
-- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS account_profiles (
CREATE TABLE IF NOT EXISTS userapi_profiles (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- The display name for this account
@ -38,19 +38,27 @@ CREATE TABLE IF NOT EXISTS account_profiles (
`
const insertProfileSQL = "" +
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" +
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
"UPDATE userapi_profiles AS new" +
" SET avatar_url = $1" +
" FROM userapi_profiles AS old" +
" WHERE new.localpart = $2" +
" RETURNING new.display_name, old.avatar_url <> new.avatar_url"
const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
"UPDATE userapi_profiles AS new" +
" SET display_name = $1" +
" FROM userapi_profiles AS old" +
" WHERE new.localpart = $2" +
" RETURNING new.avatar_url, old.display_name <> new.display_name"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct {
serverNoticesLocalpart string
@ -100,16 +108,28 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
return
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
AvatarURL: avatarURL,
}
var changed bool
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed)
return profile, changed, err
}
func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
DisplayName: displayName,
}
var changed bool
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed)
return profile, changed, err
}
func (s *profilesStatements) SelectProfilesBySearch(

View file

@ -45,7 +45,7 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON usera
const countUsersLastSeenAfterSQL = "" +
"SELECT COUNT(*) FROM (" +
" SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " +
" SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " +
" GROUP BY localpart" +
" ) u"
@ -62,7 +62,7 @@ R30Users counts the number of 30 day retained users, defined as:
const countR30UsersSQL = `
SELECT platform, COUNT(*) FROM (
SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts)
FROM account_accounts users
FROM userapi_accounts users
INNER JOIN
(SELECT
localpart, last_seen_ts,
@ -75,7 +75,7 @@ SELECT platform, COUNT(*) FROM (
ELSE 'unknown'
END
AS platform
FROM device_devices
FROM userapi_devices
) uip
ON users.localpart = uip.localpart
AND users.account_type <> 4
@ -121,7 +121,7 @@ GROUP BY client_type
`
const countUserByAccountTypeSQL = `
SELECT COUNT(*) FROM account_accounts WHERE account_type = ANY($1)
SELECT COUNT(*) FROM userapi_accounts WHERE account_type = ANY($1)
`
// $1 = All non guest AccountType IDs
@ -134,7 +134,7 @@ SELECT user_type, COUNT(*) AS count FROM (
WHEN account_type = $2 AND appservice_id IS NULL THEN 'guest'
WHEN account_type = ANY($1) AND appservice_id IS NOT NULL THEN 'bridged'
END AS user_type
FROM account_accounts
FROM userapi_accounts
WHERE created_ts > $3
) AS t GROUP BY user_type
`
@ -143,14 +143,14 @@ SELECT user_type, COUNT(*) AS count FROM (
const updateUserDailyVisitsSQL = `
INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent)
SELECT u.localpart, u.device_id, $1, MAX(u.user_agent)
FROM device_devices AS u
FROM userapi_devices AS u
LEFT JOIN (
SELECT localpart, device_id, timestamp FROM userapi_daily_visits
WHERE timestamp = $1
) udv
ON u.localpart = udv.localpart AND u.device_id = udv.device_id
INNER JOIN device_devices d ON d.localpart = u.localpart
INNER JOIN account_accounts a ON a.localpart = u.localpart
INNER JOIN userapi_devices d ON d.localpart = u.localpart
INNER JOIN userapi_accounts a ON a.localpart = u.localpart
WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3
AND a.account_type in (1, 3)
GROUP BY u.localpart, u.device_id

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/shared"
// Import the postgres database driver.
@ -36,6 +37,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
return nil, err
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: rename tables",
Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables,
})
if err = m.Up(base.Context()); err != nil {
return nil, err
}
accountDataTable, err := NewPostgresAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)

View file

@ -26,7 +26,7 @@ import (
const threepidSchema = `
-- Stores data about third party identifiers
CREATE TABLE IF NOT EXISTS account_threepid (
CREATE TABLE IF NOT EXISTS userapi_threepids (
-- The third party identifier
threepid TEXT NOT NULL,
-- The 3PID medium
@ -37,20 +37,20 @@ CREATE TABLE IF NOT EXISTS account_threepid (
PRIMARY KEY(threepid, medium)
);
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart);
CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart);
`
const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
const insertThreePIDSQL = "" +
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
const deleteThreePIDSQL = "" +
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
type threepidStatements struct {
selectLocalpartForThreePIDStmt *sql.Stmt

View file

@ -96,20 +96,24 @@ func (d *Database) GetProfileByLocalpart(
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
return err
})
return
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
return err
})
return
}
// SetPassword sets the account password to the given hash.

View file

@ -25,7 +25,7 @@ import (
const accountDataSchema = `
-- Stores data about accounts data.
CREATE TABLE IF NOT EXISTS account_data (
CREATE TABLE IF NOT EXISTS userapi_account_datas (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room)
@ -40,15 +40,15 @@ CREATE TABLE IF NOT EXISTS account_data (
`
const insertAccountDataSQL = `
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
`
const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM account_data WHERE localpart = $1"
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
const selectAccountDataByTypeSQL = "" +
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
type accountDataStatements struct {
db *sql.DB

View file

@ -32,7 +32,7 @@ import (
const accountsSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS account_accounts (
CREATE TABLE IF NOT EXISTS userapi_accounts (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- When this account was first created, as a unix timestamp (ms resolution).
@ -51,22 +51,22 @@ CREATE TABLE IF NOT EXISTS account_accounts (
`
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
const deactivateAccountSQL = "" +
"UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0"
const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0"
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0"
type accountsStatements struct {
db *sql.DB

View file

@ -8,8 +8,8 @@ import (
func UpIsActive(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts (
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL,
password_hash TEXT,
@ -17,13 +17,13 @@ CREATE TABLE account_accounts (
is_deactivated BOOLEAN DEFAULT 0
);
INSERT
INTO account_accounts (
INTO userapi_accounts (
localpart, created_ts, password_hash, appservice_id
) SELECT
localpart, created_ts, password_hash, appservice_id
FROM account_accounts_tmp
FROM userapi_accounts_tmp
;
DROP TABLE account_accounts_tmp;`)
DROP TABLE userapi_accounts_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
@ -32,21 +32,21 @@ DROP TABLE account_accounts_tmp;`)
func DownIsActive(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts (
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL,
password_hash TEXT,
appservice_id TEXT
);
INSERT
INTO account_accounts (
INTO userapi_accounts (
localpart, created_ts, password_hash, appservice_id
) SELECT
localpart, created_ts, password_hash, appservice_id
FROM account_accounts_tmp
FROM userapi_accounts_tmp
;
DROP TABLE account_accounts_tmp;`)
DROP TABLE userapi_accounts_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}

View file

@ -8,8 +8,8 @@ import (
func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices RENAME TO device_devices_tmp;
CREATE TABLE device_devices (
ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp;
CREATE TABLE userapi_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
@ -22,12 +22,12 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
UNIQUE (localpart, device_id)
);
INSERT
INTO device_devices (
INTO userapi_devices (
access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent
) SELECT
access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', ''
FROM device_devices_tmp;
DROP TABLE device_devices_tmp;`)
FROM userapi_devices_tmp;
DROP TABLE userapi_devices_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
@ -36,8 +36,8 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices RENAME TO device_devices_tmp;
CREATE TABLE IF NOT EXISTS device_devices (
ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp;
CREATE TABLE IF NOT EXISTS userapi_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
@ -47,12 +47,12 @@ CREATE TABLE IF NOT EXISTS device_devices (
UNIQUE (localpart, device_id)
);
INSERT
INTO device_devices (
INTO userapi_devices (
access_token, session_id, device_id, localpart, created_ts, display_name
) SELECT
access_token, session_id, device_id, localpart, created_ts, display_name
FROM device_devices_tmp;
DROP TABLE device_devices_tmp;`)
FROM userapi_devices_tmp;
DROP TABLE userapi_devices_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}

View file

@ -9,8 +9,8 @@ import (
func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
// initially set every account to useraccount, change appservice and guest accounts afterwards
// (user = 1, guest = 2, admin = 3, appservice = 4)
_, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts (
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL,
password_hash TEXT,
@ -19,15 +19,15 @@ CREATE TABLE account_accounts (
account_type INTEGER NOT NULL
);
INSERT
INTO account_accounts (
INTO userapi_accounts (
localpart, created_ts, password_hash, appservice_id, account_type
) SELECT
localpart, created_ts, password_hash, appservice_id, 1
FROM account_accounts_tmp
FROM userapi_accounts_tmp
;
UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE account_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*';
DROP TABLE account_accounts_tmp;`)
UPDATE userapi_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE userapi_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*';
DROP TABLE userapi_accounts_tmp;`)
if err != nil {
return fmt.Errorf("failed to add column: %w", err)
}
@ -35,7 +35,7 @@ DROP TABLE account_accounts_tmp;`)
}
func DownAddAccountType(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts DROP COLUMN account_type;`)
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts DROP COLUMN account_type;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}

View file

@ -0,0 +1,109 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"strings"
)
var renameTableMappings = map[string]string{
"account_accounts": "userapi_accounts",
"account_data": "userapi_account_datas",
"device_devices": "userapi_devices",
"account_e2e_room_keys": "userapi_key_backups",
"account_e2e_room_keys_versions": "userapi_key_backup_versions",
"login_tokens": "userapi_login_tokens",
"open_id_tokens": "userapi_openid_tokens",
"account_profiles": "userapi_profiles",
"account_threepid": "userapi_threepids",
}
var renameIndicesMappings = map[string]string{
"device_localpart_id_idx": "userapi_device_localpart_id_idx",
"e2e_room_keys_idx": "userapi_key_backups_idx",
"e2e_room_keys_versions_idx": "userapi_key_backups_versions_idx",
"account_e2e_room_keys_versions_idx": "userapi_key_backup_versions_idx",
"login_tokens_expiration_idx": "userapi_login_tokens_expiration_idx",
"account_threepid_localpart": "userapi_threepid_idx",
}
func UpRenameTables(ctx context.Context, tx *sql.Tx) error {
for old, new := range renameTableMappings {
// SQLite has no "IF EXISTS" so check if the table exists.
var name string
if err := tx.QueryRowContext(
ctx, "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = $1;", old,
).Scan(&name); err != nil {
if err == sql.ErrNoRows {
continue
}
return err
}
q := fmt.Sprintf(
"ALTER TABLE %s RENAME TO %s;", old, new,
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
}
}
for old, new := range renameIndicesMappings {
var query string
if err := tx.QueryRowContext(
ctx, "SELECT sql FROM sqlite_schema WHERE type = 'index' AND name = $1;", old,
).Scan(&query); err != nil {
if err == sql.ErrNoRows {
continue
}
return err
}
query = strings.Replace(query, old, new, 1)
if _, err := tx.ExecContext(ctx, fmt.Sprintf("DROP INDEX %s;", old)); err != nil {
return fmt.Errorf("drop index %q to %q error: %w", old, new, err)
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return fmt.Errorf("recreate index %q to %q error: %w", old, new, err)
}
}
return nil
}
func DownRenameTables(ctx context.Context, tx *sql.Tx) error {
for old, new := range renameTableMappings {
// SQLite has no "IF EXISTS" so check if the table exists.
var name string
if err := tx.QueryRowContext(
ctx, "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = $1;", new,
).Scan(&name); err != nil {
if err == sql.ErrNoRows {
continue
}
return err
}
q := fmt.Sprintf(
"ALTER TABLE %s RENAME TO %s;", new, old,
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
}
}
for old, new := range renameIndicesMappings {
var query string
if err := tx.QueryRowContext(
ctx, "SELECT sql FROM sqlite_schema WHERE type = 'index' AND name = $1;", new,
).Scan(&query); err != nil {
if err == sql.ErrNoRows {
continue
}
return err
}
query = strings.Replace(query, new, old, 1)
if _, err := tx.ExecContext(ctx, fmt.Sprintf("DROP INDEX %s;", new)); err != nil {
return fmt.Errorf("drop index %q to %q error: %w", new, old, err)
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return fmt.Errorf("recreate index %q to %q error: %w", new, old, err)
}
}
return nil
}

View file

@ -35,7 +35,7 @@ const devicesSchema = `
-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices (
CREATE TABLE IF NOT EXISTS userapi_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
@ -51,38 +51,38 @@ CREATE TABLE IF NOT EXISTS device_devices (
`
const insertDeviceSQL = "" +
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
"INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM device_devices"
"SELECT COUNT(access_token) FROM userapi_devices"
const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2"
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" +
"UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
type devicesStatements struct {
db *sql.DB

View file

@ -26,7 +26,7 @@ import (
)
const keyBackupTableSchema = `
CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
CREATE TABLE IF NOT EXISTS userapi_key_backups (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
session_id TEXT NOT NULL,
@ -37,31 +37,31 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
is_verified BOOLEAN NOT NULL,
session_data TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON userapi_key_backups(user_id, room_id, session_id, version);
CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON userapi_key_backups(user_id, version);
`
const insertBackupKeySQL = "" +
"INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
"INSERT INTO userapi_key_backups(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
const updateBackupKeySQL = "" +
"UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
"UPDATE userapi_key_backups SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
"WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8"
const countKeysSQL = "" +
"SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2"
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
const selectKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2"
const selectKeysByRoomIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3"
const selectKeysByRoomIDAndSessionIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
type keyBackupStatements struct {

View file

@ -27,7 +27,7 @@ import (
const keyBackupVersionTableSchema = `
-- the metadata for each generation of encrypted e2e session backups
CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
CREATE TABLE IF NOT EXISTS userapi_key_backup_versions (
user_id TEXT NOT NULL,
-- this means no 2 users will ever have the same version of e2e session backups which strictly
-- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1.
@ -38,26 +38,26 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
deleted INTEGER DEFAULT 0 NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backup_versions_idx ON userapi_key_backup_versions(user_id, version);
`
const insertKeyBackupSQL = "" +
"INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
"INSERT INTO userapi_key_backup_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
const updateKeyBackupAuthDataSQL = "" +
"UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
"UPDATE userapi_key_backup_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
const updateKeyBackupETagSQL = "" +
"UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
"UPDATE userapi_key_backup_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
const deleteKeyBackupSQL = "" +
"UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
"UPDATE userapi_key_backup_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
const selectKeyBackupSQL = "" +
"SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2"
"SELECT algorithm, auth_data, etag, deleted FROM userapi_key_backup_versions WHERE user_id = $1 AND version = $2"
const selectLatestVersionSQL = "" +
"SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1"
"SELECT MAX(version) FROM userapi_key_backup_versions WHERE user_id = $1"
type keyBackupVersionStatements struct {
insertKeyBackupStmt *sql.Stmt

View file

@ -32,7 +32,7 @@ type loginTokenStatements struct {
}
const loginTokenSchema = `
CREATE TABLE IF NOT EXISTS login_tokens (
CREATE TABLE IF NOT EXISTS userapi_login_tokens (
-- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
-- When the token expires
@ -43,17 +43,17 @@ CREATE TABLE IF NOT EXISTS login_tokens (
);
-- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON userapi_login_tokens(token_expires_at);
`
const insertLoginTokenSQL = "" +
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
"INSERT INTO userapi_login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
const deleteLoginTokenSQL = "" +
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
"DELETE FROM userapi_login_tokens WHERE token = $1 OR token_expires_at <= $2"
const selectLoginTokenSQL = "" +
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
"SELECT user_id FROM userapi_login_tokens WHERE token = $1 AND token_expires_at > $2"
func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
s := &loginTokenStatements{}
@ -78,7 +78,7 @@ func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx
// deleteByToken removes the named token.
//
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient.
// The userapi_login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())

View file

@ -13,7 +13,7 @@ import (
const openIDTokenSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens (
CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
-- The value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
@ -24,10 +24,10 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
`
const insertOpenIDTokenSQL = "" +
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct {
db *sql.DB

View file

@ -27,7 +27,7 @@ import (
const profilesSchema = `
-- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS account_profiles (
CREATE TABLE IF NOT EXISTS userapi_profiles (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- The display name for this account
@ -38,19 +38,21 @@ CREATE TABLE IF NOT EXISTS account_profiles (
`
const insertProfileSQL = "" +
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" +
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
" RETURNING display_name"
const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
" RETURNING avatar_url"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct {
db *sql.DB
@ -102,18 +104,40 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) {
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
AvatarURL: avatarURL,
}
old, err := s.SelectProfileByLocalpart(ctx, localpart)
if err != nil {
return old, false, err
}
if old.AvatarURL == avatarURL {
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
_, err = stmt.ExecContext(ctx, avatarURL, localpart)
return
err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
return profile, true, err
}
func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) {
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
DisplayName: displayName,
}
old, err := s.SelectProfileByLocalpart(ctx, localpart)
if err != nil {
return old, false, err
}
if old.DisplayName == displayName {
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
_, err = stmt.ExecContext(ctx, displayName, localpart)
return
err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
return profile, true, err
}
func (s *profilesStatements) SelectProfilesBySearch(

View file

@ -46,7 +46,7 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON usera
const countUsersLastSeenAfterSQL = "" +
"SELECT COUNT(*) FROM (" +
" SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " +
" SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " +
" GROUP BY localpart" +
" ) u"
@ -63,7 +63,7 @@ R30Users counts the number of 30 day retained users, defined as:
const countR30UsersSQL = `
SELECT platform, COUNT(*) FROM (
SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts)
FROM account_accounts users
FROM userapi_accounts users
INNER JOIN
(SELECT
localpart, last_seen_ts,
@ -76,7 +76,7 @@ SELECT platform, COUNT(*) FROM (
ELSE 'unknown'
END
AS platform
FROM device_devices
FROM userapi_devices
) uip
ON users.localpart = uip.localpart
AND users.account_type <> 4
@ -126,7 +126,7 @@ GROUP BY client_type
`
const countUserByAccountTypeSQL = `
SELECT COUNT(*) FROM account_accounts WHERE account_type IN ($1)
SELECT COUNT(*) FROM userapi_accounts WHERE account_type IN ($1)
`
// $1 = Guest AccountType
@ -139,7 +139,7 @@ SELECT user_type, COUNT(*) AS count FROM (
WHEN account_type = $4 AND appservice_id IS NULL THEN 'guest'
WHEN account_type IN ($5) AND appservice_id IS NOT NULL THEN 'bridged'
END AS user_type
FROM account_accounts
FROM userapi_accounts
WHERE created_ts > $8
) AS t GROUP BY user_type
`
@ -148,14 +148,14 @@ SELECT user_type, COUNT(*) AS count FROM (
const updateUserDailyVisitsSQL = `
INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent)
SELECT u.localpart, u.device_id, $1, MAX(u.user_agent)
FROM device_devices AS u
FROM userapi_devices AS u
LEFT JOIN (
SELECT localpart, device_id, timestamp FROM userapi_daily_visits
WHERE timestamp = $1
) udv
ON u.localpart = udv.localpart AND u.device_id = udv.device_id
INNER JOIN device_devices d ON d.localpart = u.localpart
INNER JOIN account_accounts a ON a.localpart = u.localpart
INNER JOIN userapi_devices d ON d.localpart = u.localpart
INNER JOIN userapi_accounts a ON a.localpart = u.localpart
WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3
AND a.account_type in (1, 3)
GROUP BY u.localpart, u.device_id

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/shared"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
)
// NewDatabase creates a new accounts and profiles database
@ -34,6 +35,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
return nil, err
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: rename tables",
Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables,
})
if err = m.Up(base.Context()); err != nil {
return nil, err
}
accountDataTable, err := NewSQLiteAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)

View file

@ -27,7 +27,7 @@ import (
const threepidSchema = `
-- Stores data about third party identifiers
CREATE TABLE IF NOT EXISTS account_threepid (
CREATE TABLE IF NOT EXISTS userapi_threepids (
-- The third party identifier
threepid TEXT NOT NULL,
-- The 3PID medium
@ -38,20 +38,20 @@ CREATE TABLE IF NOT EXISTS account_threepid (
PRIMARY KEY(threepid, medium)
);
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart);
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart);
`
const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
const insertThreePIDSQL = "" +
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
const deleteThreePIDSQL = "" +
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
type threepidStatements struct {
db *sql.DB

View file

@ -16,6 +16,7 @@ import (
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables"
@ -29,14 +30,18 @@ var (
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{
db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
if err != nil {
t.Fatalf("NewUserAPIDatabase returned %s", err)
}
return db, close
return db, func() {
close()
baseclose()
}
}
// Tests storing and getting account data
@ -377,15 +382,23 @@ func Test_Profile(t *testing.T) {
// set avatar & displayname
wantProfile.DisplayName = "Alice"
wantProfile.AvatarURL = "mxc://aliceAvatar"
err = db.SetDisplayName(ctx, aliceLocalpart, "Alice")
assert.NoError(t, err, "unable to set displayname")
err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
// verify profile
gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart)
assert.NoError(t, err, "unable to get profile by localpart")
gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice")
assert.Equal(t, wantProfile, gotProfile)
assert.NoError(t, err, "unable to set displayname")
assert.True(t, changed)
wantProfile.AvatarURL = "mxc://aliceAvatar"
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile)
assert.True(t, changed)
// Setting the same avatar again doesn't change anything
wantProfile.AvatarURL = "mxc://aliceAvatar"
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile)
assert.False(t, changed)
// search profiles
searchRes, err := db.SearchProfiles(ctx, "Alice", 2)

View file

@ -84,8 +84,8 @@ type OpenIDTable interface {
type ProfileTable interface {
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error)
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error)
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error)
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
}

View file

@ -106,7 +106,7 @@ func mustUpdateDeviceLastSeen(
timestamp time.Time,
) {
t.Helper()
_, err := db.ExecContext(ctx, "UPDATE device_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart)
_, err := db.ExecContext(ctx, "UPDATE userapi_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart)
if err != nil {
t.Fatalf("unable to update device last seen")
}
@ -119,7 +119,7 @@ func mustUserUpdateRegistered(
localpart string,
timestamp time.Time,
) {
_, err := db.ExecContext(ctx, "UPDATE account_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart)
_, err := db.ExecContext(ctx, "UPDATE userapi_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart)
if err != nil {
t.Fatalf("unable to update device last seen")
}

View file

@ -23,13 +23,15 @@ import (
"time"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/internal"
@ -48,9 +50,9 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap
if opts.loginTokenLifetime == 0 {
opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
}
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
connStr, close := test.PrepareDBConnectionString(t, dbType)
accountDB, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{
accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
if err != nil {
@ -66,7 +68,10 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap
return &internal.UserInternalAPI{
DB: accountDB,
ServerName: cfg.Matrix.ServerName,
}, accountDB, close
}, accountDB, func() {
close()
baseclose()
}
}
func TestQueryProfile(t *testing.T) {
@ -79,10 +84,10 @@ func TestQueryProfile(t *testing.T) {
if err != nil {
t.Fatalf("failed to make account: %s", err)
}
if err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
t.Fatalf("failed to set avatar url: %s", err)
}
if err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
t.Fatalf("failed to set display name: %s", err)
}