Merge branch 'main' into neilalexander/localservernames

This commit is contained in:
Neil Alexander 2022-10-24 11:29:54 +01:00
commit 107135f573
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
78 changed files with 2574 additions and 727 deletions

View file

@ -1,5 +1,23 @@
# Changelog
## Dendrite 0.10.4 (2022-10-21)
### Features
* Various tables belonging to the user API will be renamed so that they are namespaced with the `userapi_` prefix
* Note that, after upgrading to this version, you should not revert to an older version of Dendrite as the database changes **will not** be reverted automatically
* The backoff and retry behaviour in the federation API has been refactored and improved
### Fixes
* Private read receipt support is now advertised in the client `/versions` endpoint
* Private read receipts will now clear notification counts properly
* A bug where a false `leave` membership transition was inserted into the timeline after accepting an invite has been fixed
* Some panics caused by concurrent map writes in the key server have been fixed
* The sync API now calculates membership transitions from state deltas more accurately
* Transaction IDs are now scoped to endpoints, which should fix some bugs where transaction ID reuse could cause nonsensical cached responses from some endpoints
* The length of the `type`, `sender`, `state_key` and `room_id` fields in events are now verified by number of bytes rather than codepoints after a spec clarification, reverting a change made in Dendrite 0.9.6
## Dendrite 0.10.3 (2022-10-14)
### Features

View file

@ -99,7 +99,11 @@ func (r *queryKeysRequest) GetTimeout() time.Duration {
if r.Timeout == 0 {
return 10 * time.Second
}
return time.Duration(r.Timeout) * time.Millisecond
timeout := time.Duration(r.Timeout) * time.Millisecond
if timeout > time.Second*20 {
timeout = time.Second * 20
}
return timeout
}
func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse {

View file

@ -19,6 +19,8 @@ import (
"net/http"
"time"
"github.com/matrix-org/gomatrixserverlib"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
@ -27,7 +29,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/util"
@ -126,21 +127,6 @@ func SetAvatarURL(
}
}
res := &userapi.QueryProfileResponse{}
err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{
UserID: userID,
}, res)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed")
return jsonerror.InternalServerError()
}
oldProfile := &authtypes.Profile{
Localpart: localpart,
ServerName: string(domain),
DisplayName: res.DisplayName,
AvatarURL: res.AvatarURL,
}
setRes := &userapi.PerformSetAvatarURLResponse{}
if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{
Localpart: localpart,
@ -150,42 +136,17 @@ func SetAvatarURL(
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed")
return jsonerror.InternalServerError()
}
var roomsRes api.QueryRoomsForUserResponse
err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
UserID: device.UserID,
WantMembership: "join",
}, &roomsRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed")
return jsonerror.InternalServerError()
}
newProfile := authtypes.Profile{
Localpart: localpart,
ServerName: string(domain),
DisplayName: oldProfile.DisplayName,
AvatarURL: r.AvatarURL,
}
events, err := buildMembershipEvents(
req.Context(), roomsRes.RoomIDs, newProfile, userID, cfg, evTime, rsAPI,
)
switch e := err.(type) {
case nil:
case gomatrixserverlib.BadJSONError:
// No need to build new membership events, since nothing changed
if !setRes.Changed {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(e.Error()),
Code: http.StatusOK,
JSON: struct{}{},
}
default:
util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed")
return jsonerror.InternalServerError()
}
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, domain, domain, nil, true); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError()
response, err := updateProfile(req.Context(), rsAPI, device, setRes.Profile, userID, cfg, evTime)
if err != nil {
return response
}
return util.JSONResponse{
@ -258,50 +219,58 @@ func SetDisplayName(
}
}
pRes := &userapi.QueryProfileResponse{}
err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{
UserID: userID,
}, pRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed")
return jsonerror.InternalServerError()
}
oldProfile := &authtypes.Profile{
Localpart: localpart,
ServerName: string(domain),
DisplayName: pRes.DisplayName,
AvatarURL: pRes.AvatarURL,
}
profileRes := &userapi.PerformUpdateDisplayNameResponse{}
err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{
Localpart: localpart,
ServerName: domain,
DisplayName: r.DisplayName,
}, &struct{}{})
}, profileRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed")
return jsonerror.InternalServerError()
}
// No need to build new membership events, since nothing changed
if !profileRes.Changed {
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
response, err := updateProfile(req.Context(), rsAPI, device, profileRes.Profile, userID, cfg, evTime)
if err != nil {
return response
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
func updateProfile(
ctx context.Context, rsAPI api.ClientRoomserverAPI, device *userapi.Device,
profile *authtypes.Profile,
userID string, cfg *config.ClientAPI, evTime time.Time,
) (util.JSONResponse, error) {
var res api.QueryRoomsForUserResponse
err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: device.UserID,
WantMembership: "join",
}, &res)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed")
return jsonerror.InternalServerError()
util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed")
return jsonerror.InternalServerError(), err
}
newProfile := authtypes.Profile{
Localpart: localpart,
ServerName: string(domain),
DisplayName: r.DisplayName,
AvatarURL: oldProfile.AvatarURL,
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError(), err
}
events, err := buildMembershipEvents(
req.Context(), res.RoomIDs, newProfile, userID, cfg, evTime, rsAPI,
ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI,
)
switch e := err.(type) {
case nil:
@ -309,21 +278,17 @@ func SetDisplayName(
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(e.Error()),
}
}, e
default:
util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed")
return jsonerror.InternalServerError()
util.GetLogger(ctx).WithError(err).Error("buildMembershipEvents failed")
return jsonerror.InternalServerError(), e
}
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, domain, domain, nil, true); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, domain, domain, nil, true); err != nil {
util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError(), err
}
return util.JSONResponse{}, nil
}
// getProfile gets the full profile of a user by querying the database or a

View file

@ -178,7 +178,7 @@ func Setup(
// server notifications
if cfg.Matrix.ServerNotices.Enabled {
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")
serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, cfg)
serverNotificationSender, err := getSenderDevice(context.Background(), rsAPI, userAPI, cfg)
if err != nil {
logrus.WithError(err).Fatal("unable to get account for sending sending server notices")
}

View file

@ -277,6 +277,7 @@ func (r sendServerNoticeRequest) valid() (ok bool) {
// It returns an userapi.Device, which is used for building the event
func getSenderDevice(
ctx context.Context,
rsAPI api.ClientRoomserverAPI,
userAPI userapi.ClientUserAPI,
cfg *config.ClientAPI,
) (*userapi.Device, error) {
@ -291,16 +292,32 @@ func getSenderDevice(
return nil, err
}
// set the avatarurl for the user
res := &userapi.PerformSetAvatarURLResponse{}
// Set the avatarurl for the user
avatarRes := &userapi.PerformSetAvatarURLResponse{}
if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart,
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
}, res); err != nil {
}, avatarRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed")
return nil, err
}
profile := avatarRes.Profile
// Set the displayname for the user
displayNameRes := &userapi.PerformUpdateDisplayNameResponse{}
if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart,
DisplayName: cfg.Matrix.ServerNotices.DisplayName,
}, displayNameRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed")
return nil, err
}
if displayNameRes.Changed {
profile.DisplayName = cfg.Matrix.ServerNotices.DisplayName
}
// Check if we got existing devices
deviceRes := &userapi.QueryDevicesResponse{}
err = userAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{
@ -310,7 +327,15 @@ func getSenderDevice(
return nil, err
}
// We've got an existing account, return the first device of it
if len(deviceRes.Devices) > 0 {
// If there were changes to the profile, create a new membership event
if displayNameRes.Changed || avatarRes.Changed {
_, err = updateProfile(ctx, rsAPI, &deviceRes.Devices[0], profile, accRes.Account.UserID, cfg, time.Now())
if err != nil {
return nil, err
}
}
return &deviceRes.Devices[0], nil
}

View file

@ -231,9 +231,9 @@ GEM
jekyll-seo-tag (~> 2.1)
minitest (5.15.0)
multipart-post (2.1.1)
nokogiri (1.13.6-arm64-darwin)
nokogiri (1.13.9-arm64-darwin)
racc (~> 1.4)
nokogiri (1.13.6-x86_64-linux)
nokogiri (1.13.9-x86_64-linux)
racc (~> 1.4)
octokit (4.22.0)
faraday (>= 0.9)

View file

@ -116,17 +116,14 @@ func NewInternalAPI(
_ = federationDB.RemoveAllServersFromBlacklist()
}
stats := &statistics.Statistics{
DB: federationDB,
FailuresUntilBlacklist: cfg.FederationMaxRetries,
}
stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1)
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
queues := queue.NewOutgoingQueues(
federationDB, base.ProcessContext,
cfg.Matrix.DisableFederation,
cfg.Matrix.ServerName, federation, rsAPI, stats,
cfg.Matrix.ServerName, federation, rsAPI, &stats,
&queue.SigningInfo{
KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey,
@ -183,5 +180,5 @@ func NewInternalAPI(
}
time.AfterFunc(time.Minute, cleanExpiredEDUs)
return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues, keyRing)
return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, &stats, caches, queues, keyRing)
}

View file

@ -21,21 +21,22 @@ import (
"sync"
"time"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"go.uber.org/atomic"
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"go.uber.org/atomic"
)
const (
maxPDUsPerTransaction = 50
maxEDUsPerTransaction = 50
maxEDUsPerTransaction = 100
maxPDUsInMemory = 128
maxEDUsInMemory = 128
queueIdleTimeout = time.Second * 30
@ -64,7 +65,6 @@ type destinationQueue struct {
pendingPDUs []*queuedPDU // PDUs waiting to be sent
pendingEDUs []*queuedEDU // EDUs waiting to be sent
pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs
interruptBackoff chan bool // interrupts backoff
}
// Send event adds the event to the pending queue for the destination.
@ -75,39 +75,22 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination)
return
}
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
if err := oq.db.AssociatePDUWithDestination(
oq.process.Context(),
"", // TODO: remove this, as we don't need to persist the transaction ID
oq.destination, // the destination server name
receipt, // NIDs from federationapi_queue_json table
); err != nil {
logrus.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination)
return
// If there's room in memory to hold the event then add it to the
// list.
oq.pendingMutex.Lock()
if len(oq.pendingPDUs) < maxPDUsInMemory {
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{
pdu: event,
receipt: receipt,
})
} else {
oq.overflowed.Store(true)
}
// Check if the destination is blacklisted. If it isn't then wake
// up the queue.
if !oq.statistics.Blacklisted() {
// If there's room in memory to hold the event then add it to the
// list.
oq.pendingMutex.Lock()
if len(oq.pendingPDUs) < maxPDUsInMemory {
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{
pdu: event,
receipt: receipt,
})
} else {
oq.overflowed.Store(true)
}
oq.pendingMutex.Unlock()
// Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded()
select {
case oq.notify <- struct{}{}:
default:
}
oq.pendingMutex.Unlock()
if !oq.backingOff.Load() {
oq.wakeQueueAndNotify()
}
}
@ -119,40 +102,47 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination)
return
}
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
if err := oq.db.AssociateEDUWithDestination(
oq.process.Context(),
oq.destination, // the destination server name
receipt, // NIDs from federationapi_queue_json table
event.Type,
nil, // this will use the default expireEDUTypes map
); err != nil {
logrus.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination)
return
// If there's room in memory to hold the event then add it to the
// list.
oq.pendingMutex.Lock()
if len(oq.pendingEDUs) < maxEDUsInMemory {
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{
edu: event,
receipt: receipt,
})
} else {
oq.overflowed.Store(true)
}
// Check if the destination is blacklisted. If it isn't then wake
// up the queue.
if !oq.statistics.Blacklisted() {
// If there's room in memory to hold the event then add it to the
// list.
oq.pendingMutex.Lock()
if len(oq.pendingEDUs) < maxEDUsInMemory {
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{
edu: event,
receipt: receipt,
})
} else {
oq.overflowed.Store(true)
}
oq.pendingMutex.Unlock()
// Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded()
select {
case oq.notify <- struct{}{}:
default:
}
oq.pendingMutex.Unlock()
if !oq.backingOff.Load() {
oq.wakeQueueAndNotify()
}
}
// handleBackoffNotifier is registered as the backoff notification
// callback with Statistics. It will wakeup and notify the queue
// if the queue is currently backing off.
func (oq *destinationQueue) handleBackoffNotifier() {
// Only wake up the queue if it is backing off.
// Otherwise there is no pending work for the queue to handle
// so waking the queue would be a waste of resources.
if oq.backingOff.Load() {
oq.wakeQueueAndNotify()
}
}
// wakeQueueAndNotify ensures the destination queue is running and notifies it
// that there is pending work.
func (oq *destinationQueue) wakeQueueAndNotify() {
// Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded()
// Notify the queue that there are events ready to send.
select {
case oq.notify <- struct{}{}:
default:
}
}
@ -161,10 +151,11 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
// then we will interrupt the backoff, causing any federation
// requests to retry.
func (oq *destinationQueue) wakeQueueIfNeeded() {
// If we are backing off then interrupt the backoff.
// Clear the backingOff flag and update the backoff metrics if it was set.
if oq.backingOff.CompareAndSwap(true, false) {
oq.interruptBackoff <- true
destinationQueueBackingOff.Dec()
}
// If we aren't running then wake up the queue.
if !oq.running.Load() {
// Start the queue.
@ -196,38 +187,54 @@ func (oq *destinationQueue) getPendingFromDatabase() {
gotEDUs[edu.receipt.String()] = struct{}{}
}
overflowed := false
if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 {
// We have room in memory for some PDUs - let's request no more than that.
if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil {
if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, maxPDUsInMemory); err == nil {
if len(pdus) == maxPDUsInMemory {
overflowed = true
}
for receipt, pdu := range pdus {
if _, ok := gotPDUs[receipt.String()]; ok {
continue
}
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu})
retrieved = true
if len(oq.pendingPDUs) == maxPDUsInMemory {
break
}
}
} else {
logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination)
}
}
if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 {
// We have room in memory for some EDUs - let's request no more than that.
if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil {
if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, maxEDUsInMemory); err == nil {
if len(edus) == maxEDUsInMemory {
overflowed = true
}
for receipt, edu := range edus {
if _, ok := gotEDUs[receipt.String()]; ok {
continue
}
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu})
retrieved = true
if len(oq.pendingEDUs) == maxEDUsInMemory {
break
}
}
} else {
logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination)
}
}
// If we've retrieved all of the events from the database with room to spare
// in memory then we'll no longer consider this queue to be overflowed.
if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory {
if !overflowed {
oq.overflowed.Store(false)
} else {
}
// If we've retrieved some events then notify the destination queue goroutine.
if retrieved {
@ -238,6 +245,24 @@ func (oq *destinationQueue) getPendingFromDatabase() {
}
}
// checkNotificationsOnClose checks for any remaining notifications
// and starts a new backgroundSend goroutine if any exist.
func (oq *destinationQueue) checkNotificationsOnClose() {
// NOTE : If we are stopping the queue due to blacklist then it
// doesn't matter if we have been notified of new work since
// this queue instance will be deleted anyway.
if !oq.statistics.Blacklisted() {
select {
case <-oq.notify:
// We received a new notification in between the
// idle timeout firing and stopping the goroutine.
// Immediately restart the queue.
oq.wakeQueueAndNotify()
default:
}
}
}
// backgroundSend is the worker goroutine for sending events.
func (oq *destinationQueue) backgroundSend() {
// Check if a worker is already running, and if it isn't, then
@ -245,10 +270,17 @@ func (oq *destinationQueue) backgroundSend() {
if !oq.running.CompareAndSwap(false, true) {
return
}
// Register queue cleanup functions.
// NOTE : The ordering here is very intentional.
defer oq.checkNotificationsOnClose()
defer oq.running.Store(false)
destinationQueueRunning.Inc()
defer destinationQueueRunning.Dec()
defer oq.queues.clearQueue(oq)
defer oq.running.Store(false)
idleTimeout := time.NewTimer(queueIdleTimeout)
defer idleTimeout.Stop()
// Mark the queue as overflowed, so we will consult the database
// to see if there's anything new to send.
@ -261,59 +293,33 @@ func (oq *destinationQueue) backgroundSend() {
oq.getPendingFromDatabase()
}
// Reset the queue idle timeout.
if !idleTimeout.Stop() {
select {
case <-idleTimeout.C:
default:
}
}
idleTimeout.Reset(queueIdleTimeout)
// If we have nothing to do then wait either for incoming events, or
// until we hit an idle timeout.
select {
case <-oq.notify:
// There's work to do, either because getPendingFromDatabase
// told us there is, or because a new event has come in via
// sendEvent/sendEDU.
case <-time.After(queueIdleTimeout):
// told us there is, a new event has come in via sendEvent/sendEDU,
// or we are backing off and it is time to retry.
case <-idleTimeout.C:
// The worker is idle so stop the goroutine. It'll get
// restarted automatically the next time we have an event to
// send.
return
case <-oq.process.Context().Done():
// The parent process is shutting down, so stop.
oq.statistics.ClearBackoff()
return
}
// If we are backing off this server then wait for the
// backoff duration to complete first, or until explicitly
// told to retry.
until, blacklisted := oq.statistics.BackoffInfo()
if blacklisted {
// It's been suggested that we should give up because the backoff
// has exceeded a maximum allowable value. Clean up the in-memory
// buffers at this point. The PDU clean-up is already on a defer.
logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = nil
oq.pendingEDUs = nil
oq.pendingMutex.Unlock()
return
}
if until != nil && until.After(time.Now()) {
// We haven't backed off yet, so wait for the suggested amount of
// time.
duration := time.Until(*until)
logrus.Debugf("Backing off %q for %s", oq.destination, duration)
oq.backingOff.Store(true)
destinationQueueBackingOff.Inc()
select {
case <-time.After(duration):
case <-oq.interruptBackoff:
}
destinationQueueBackingOff.Dec()
oq.backingOff.Store(false)
}
// Work out which PDUs/EDUs to include in the next transaction.
oq.pendingMutex.RLock()
pduCount := len(oq.pendingPDUs)
@ -328,99 +334,52 @@ func (oq *destinationQueue) backgroundSend() {
toSendEDUs := oq.pendingEDUs[:eduCount]
oq.pendingMutex.RUnlock()
// If we didn't get anything from the database and there are no
// pending EDUs then there's nothing to do - stop here.
if pduCount == 0 && eduCount == 0 {
continue
}
// If we have pending PDUs or EDUs then construct a transaction.
// Try sending the next transaction and see what happens.
transaction, pc, ec, terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
if terr != nil {
// We failed to send the transaction. Mark it as a failure.
oq.statistics.Failure()
} else if transaction {
// If we successfully sent the transaction then clear out
// the pending events and EDUs, and wipe our transaction ID.
oq.statistics.Success()
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs[:pc] {
oq.pendingPDUs[i] = nil
_, blacklisted := oq.statistics.Failure()
if !blacklisted {
// Register the backoff state and exit the goroutine.
// It'll get restarted automatically when the backoff
// completes.
oq.backingOff.Store(true)
destinationQueueBackingOff.Inc()
return
} else {
// Immediately trigger the blacklist logic.
oq.blacklistDestination()
return
}
for i := range oq.pendingEDUs[:ec] {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = oq.pendingPDUs[pc:]
oq.pendingEDUs = oq.pendingEDUs[ec:]
oq.pendingMutex.Unlock()
} else {
oq.handleTransactionSuccess(pduCount, eduCount)
}
}
}
// nextTransaction creates a new transaction from the pending event
// queue and sends it. Returns true if a transaction was sent or
// false otherwise.
// queue and sends it.
// Returns an error if the transaction wasn't sent.
func (oq *destinationQueue) nextTransaction(
pdus []*queuedPDU,
edus []*queuedEDU,
) (bool, int, int, error) {
// If there's no projected transaction ID then generate one. If
// the transaction succeeds then we'll set it back to "" so that
// we generate a new one next time. If it fails, we'll preserve
// it so that we retry with the same transaction ID.
oq.transactionIDMutex.Lock()
if oq.transactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
oq.transactionIDMutex.Unlock()
) error {
// Create the transaction.
t := gomatrixserverlib.Transaction{
PDUs: []json.RawMessage{},
EDUs: []gomatrixserverlib.EDU{},
}
t.Origin = oq.origin
t.Destination = oq.destination
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = oq.transactionID
// If we didn't get anything from the database and there are no
// pending EDUs then there's nothing to do - stop here.
if len(pdus) == 0 && len(edus) == 0 {
return false, 0, 0, nil
}
var pduReceipts []*shared.Receipt
var eduReceipts []*shared.Receipt
// Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction.
for _, pdu := range pdus {
if pdu == nil || pdu.pdu == nil {
continue
}
// Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, pdu.pdu.JSON())
pduReceipts = append(pduReceipts, pdu.receipt)
}
// Do the same for pending EDUS in the queue.
for _, edu := range edus {
if edu == nil || edu.edu == nil {
continue
}
t.EDUs = append(t.EDUs, *edu.edu)
eduReceipts = append(eduReceipts, edu.receipt)
}
t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus)
logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
// Try to send the transaction to the destination server.
// TODO: we should check for 500-ish fails vs 400-ish here,
// since we shouldn't queue things indefinitely in response
// to a 400-ish error
ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5)
defer cancel()
_, err := oq.client.SendTransaction(ctx, t)
switch err.(type) {
switch errResponse := err.(type) {
case nil:
// Clean up the transaction in the database.
if pduReceipts != nil {
@ -439,16 +398,129 @@ func (oq *destinationQueue) nextTransaction(
oq.transactionIDMutex.Lock()
oq.transactionID = ""
oq.transactionIDMutex.Unlock()
return true, len(t.PDUs), len(t.EDUs), nil
return nil
case gomatrix.HTTPError:
// Report that we failed to send the transaction and we
// will retry again, subject to backoff.
return false, 0, 0, err
// TODO: we should check for 500-ish fails vs 400-ish here,
// since we shouldn't queue things indefinitely in response
// to a 400-ish error
code := errResponse.Code
logrus.Debug("Transaction failed with HTTP", code)
return err
default:
logrus.WithFields(logrus.Fields{
"destination": oq.destination,
logrus.ErrorKey: err,
}).Debugf("Failed to send transaction %q", t.TransactionID)
return false, 0, 0, err
return err
}
}
// createTransaction generates a gomatrixserverlib.Transaction from the provided pdus and edus.
// It also returns the associated event receipts so they can be cleaned from the database in
// the case of a successful transaction.
func (oq *destinationQueue) createTransaction(
pdus []*queuedPDU,
edus []*queuedEDU,
) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) {
// If there's no projected transaction ID then generate one. If
// the transaction succeeds then we'll set it back to "" so that
// we generate a new one next time. If it fails, we'll preserve
// it so that we retry with the same transaction ID.
oq.transactionIDMutex.Lock()
if oq.transactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
oq.transactionIDMutex.Unlock()
t := gomatrixserverlib.Transaction{
PDUs: []json.RawMessage{},
EDUs: []gomatrixserverlib.EDU{},
}
t.Origin = oq.origin
t.Destination = oq.destination
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = oq.transactionID
var pduReceipts []*shared.Receipt
var eduReceipts []*shared.Receipt
// Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction.
for _, pdu := range pdus {
// These should never be nil.
if pdu == nil || pdu.pdu == nil {
continue
}
// Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, pdu.pdu.JSON())
pduReceipts = append(pduReceipts, pdu.receipt)
}
// Do the same for pending EDUS in the queue.
for _, edu := range edus {
// These should never be nil.
if edu == nil || edu.edu == nil {
continue
}
t.EDUs = append(t.EDUs, *edu.edu)
eduReceipts = append(eduReceipts, edu.receipt)
}
return t, pduReceipts, eduReceipts
}
// blacklistDestination removes all pending PDUs and EDUs that have been cached
// and deletes this queue.
func (oq *destinationQueue) blacklistDestination() {
// It's been suggested that we should give up because the backoff
// has exceeded a maximum allowable value. Clean up the in-memory
// buffers at this point. The PDU clean-up is already on a defer.
logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = nil
oq.pendingEDUs = nil
oq.pendingMutex.Unlock()
// Delete this queue as no more messages will be sent to this
// destination until it is no longer blacklisted.
oq.statistics.AssignBackoffNotifier(nil)
oq.queues.clearQueue(oq)
}
// handleTransactionSuccess updates the cached event queues as well as the success and
// backoff information for this server.
func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) {
// If we successfully sent the transaction then clear out
// the pending events and EDUs, and wipe our transaction ID.
oq.statistics.Success()
oq.pendingMutex.Lock()
defer oq.pendingMutex.Unlock()
for i := range oq.pendingPDUs[:pduCount] {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs[:eduCount] {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = oq.pendingPDUs[pduCount:]
oq.pendingEDUs = oq.pendingEDUs[eduCount:]
if len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0 {
select {
case oq.notify <- struct{}{}:
default:
}
}
}

View file

@ -24,6 +24,7 @@ import (
"github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@ -162,23 +163,25 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d
if !ok || oq == nil {
destinationQueueTotal.Inc()
oq = &destinationQueue{
queues: oqs,
db: oqs.db,
process: oqs.process,
rsAPI: oqs.rsAPI,
origin: oqs.origin,
destination: destination,
client: oqs.client,
statistics: oqs.statistics.ForServer(destination),
notify: make(chan struct{}, 1),
interruptBackoff: make(chan bool),
signing: oqs.signing,
queues: oqs,
db: oqs.db,
process: oqs.process,
rsAPI: oqs.rsAPI,
origin: oqs.origin,
destination: destination,
client: oqs.client,
statistics: oqs.statistics.ForServer(destination),
notify: make(chan struct{}, 1),
signing: oqs.signing,
}
oq.statistics.AssignBackoffNotifier(oq.handleBackoffNotifier)
oqs.queues[destination] = oq
}
return oq
}
// clearQueue removes the queue for the provided destination from the
// set of destination queues.
func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) {
oqs.queuesMutex.Lock()
defer oqs.queuesMutex.Unlock()
@ -245,11 +248,25 @@ func (oqs *OutgoingQueues) SendEvent(
}
for destination := range destmap {
if queue := oqs.getQueue(destination); queue != nil {
if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() {
queue.sendEvent(ev, nid)
} else {
delete(destmap, destination)
}
}
// Create a database entry that associates the given PDU NID with
// this destinations queue. We'll then be able to retrieve the PDU
// later.
if err := oqs.db.AssociatePDUWithDestinations(
oqs.process.Context(),
destmap,
nid, // NIDs from federationapi_queue_json table
); err != nil {
logrus.WithError(err).Errorf("failed to associate PDUs %q with destinations", nid)
return err
}
return nil
}
@ -319,11 +336,27 @@ func (oqs *OutgoingQueues) SendEDU(
}
for destination := range destmap {
if queue := oqs.getQueue(destination); queue != nil {
if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() {
queue.sendEDU(e, nid)
} else {
delete(destmap, destination)
}
}
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
if err := oqs.db.AssociateEDUWithDestinations(
oqs.process.Context(),
destmap, // the destination server name
nid, // NIDs from federationapi_queue_json table
e.Type,
nil, // this will use the default expireEDUTypes map
); err != nil {
logrus.WithError(err).Errorf("failed to associate EDU with destinations")
return err
}
return nil
}
@ -332,7 +365,9 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
if oqs.disabled {
return
}
oqs.statistics.ForServer(srv).RemoveBlacklist()
if queue := oqs.getQueue(srv); queue != nil {
queue.statistics.ClearBackoff()
queue.wakeQueueIfNeeded()
}
}

File diff suppressed because it is too large Load diff

View file

@ -2,6 +2,7 @@ package statistics
import (
"math"
"math/rand"
"sync"
"time"
@ -20,12 +21,23 @@ type Statistics struct {
servers map[gomatrixserverlib.ServerName]*ServerStatistics
mutex sync.RWMutex
backoffTimers map[gomatrixserverlib.ServerName]*time.Timer
backoffMutex sync.RWMutex
// How many times should we tolerate consecutive failures before we
// just blacklist the host altogether? The backoff is exponential,
// so the max time here to attempt is 2**failures seconds.
FailuresUntilBlacklist uint32
}
func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics {
return Statistics{
DB: db,
FailuresUntilBlacklist: failuresUntilBlacklist,
backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer),
}
}
// ForServer returns server statistics for the given server name. If it
// does not exist, it will create empty statistics and return those.
func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics {
@ -45,7 +57,6 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
server = &ServerStatistics{
statistics: s,
serverName: serverName,
interrupt: make(chan struct{}),
}
s.servers[serverName] = server
s.mutex.Unlock()
@ -64,29 +75,43 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
// many times we failed etc. It also manages the backoff time and black-
// listing a remote host if it remains uncooperative.
type ServerStatistics struct {
statistics *Statistics //
serverName gomatrixserverlib.ServerName //
blacklisted atomic.Bool // is the node blacklisted
backoffStarted atomic.Bool // is the backoff started
backoffUntil atomic.Value // time.Time until this backoff interval ends
backoffCount atomic.Uint32 // number of times BackoffDuration has been called
interrupt chan struct{} // interrupts the backoff goroutine
successCounter atomic.Uint32 // how many times have we succeeded?
statistics *Statistics //
serverName gomatrixserverlib.ServerName //
blacklisted atomic.Bool // is the node blacklisted
backoffStarted atomic.Bool // is the backoff started
backoffUntil atomic.Value // time.Time until this backoff interval ends
backoffCount atomic.Uint32 // number of times BackoffDuration has been called
successCounter atomic.Uint32 // how many times have we succeeded?
backoffNotifier func() // notifies destination queue when backoff completes
notifierMutex sync.Mutex
}
const maxJitterMultiplier = 1.4
const minJitterMultiplier = 0.8
// duration returns how long the next backoff interval should be.
func (s *ServerStatistics) duration(count uint32) time.Duration {
return time.Second * time.Duration(math.Exp2(float64(count)))
// Add some jitter to minimise the chance of having multiple backoffs
// ending at the same time.
jitter := rand.Float64()*(maxJitterMultiplier-minJitterMultiplier) + minJitterMultiplier
duration := time.Millisecond * time.Duration(math.Exp2(float64(count))*jitter*1000)
return duration
}
// cancel will interrupt the currently active backoff.
func (s *ServerStatistics) cancel() {
s.blacklisted.Store(false)
s.backoffUntil.Store(time.Time{})
select {
case s.interrupt <- struct{}{}:
default:
}
s.ClearBackoff()
}
// AssignBackoffNotifier configures the channel to send to when
// a backoff completes.
func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) {
s.notifierMutex.Lock()
defer s.notifierMutex.Unlock()
s.backoffNotifier = notifier
}
// Success updates the server statistics with a new successful
@ -95,8 +120,8 @@ func (s *ServerStatistics) cancel() {
// we will unblacklist it.
func (s *ServerStatistics) Success() {
s.cancel()
s.successCounter.Inc()
s.backoffCount.Store(0)
s.successCounter.Inc()
if s.statistics.DB != nil {
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
@ -105,13 +130,17 @@ func (s *ServerStatistics) Success() {
}
// Failure marks a failure and starts backing off if needed.
// The next call to BackoffIfRequired will do the right thing
// after this. It will return the time that the current failure
// It will return the time that the current failure
// will result in backoff waiting until, and a bool signalling
// whether we have blacklisted and therefore to give up.
func (s *ServerStatistics) Failure() (time.Time, bool) {
// Return immediately if we have blacklisted this node.
if s.blacklisted.Load() {
return time.Time{}, true
}
// If we aren't already backing off, this call will start
// a new backoff period. Increase the failure counter and
// a new backoff period, increase the failure counter and
// start a goroutine which will wait out the backoff and
// unset the backoffStarted flag when done.
if s.backoffStarted.CompareAndSwap(false, true) {
@ -122,40 +151,48 @@ func (s *ServerStatistics) Failure() (time.Time, bool) {
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
}
}
s.ClearBackoff()
return time.Time{}, true
}
go func() {
until, ok := s.backoffUntil.Load().(time.Time)
if ok && !until.IsZero() {
select {
case <-time.After(time.Until(until)):
case <-s.interrupt:
}
s.backoffStarted.Store(false)
}
}()
// We're starting a new back off so work out what the next interval
// will be.
count := s.backoffCount.Load()
until := time.Now().Add(s.duration(count))
s.backoffUntil.Store(until)
s.statistics.backoffMutex.Lock()
defer s.statistics.backoffMutex.Unlock()
s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished)
}
// Check if we have blacklisted this node.
if s.blacklisted.Load() {
return time.Now(), true
}
return s.backoffUntil.Load().(time.Time), false
}
// If we're already backing off and we haven't yet surpassed
// the deadline then return that. Repeated calls to Failure
// within a single backoff interval will have no side effects.
if until, ok := s.backoffUntil.Load().(time.Time); ok && !time.Now().After(until) {
return until, false
// ClearBackoff stops the backoff timer for this destination if it is running
// and removes the timer from the backoffTimers map.
func (s *ServerStatistics) ClearBackoff() {
// If the timer is still running then stop it so it's memory is cleaned up sooner.
s.statistics.backoffMutex.Lock()
defer s.statistics.backoffMutex.Unlock()
if timer, ok := s.statistics.backoffTimers[s.serverName]; ok {
timer.Stop()
}
delete(s.statistics.backoffTimers, s.serverName)
// We're either backing off and have passed the deadline, or
// we aren't backing off, so work out what the next interval
// will be.
count := s.backoffCount.Load()
until := time.Now().Add(s.duration(count))
s.backoffUntil.Store(until)
return until, false
s.backoffStarted.Store(false)
}
// backoffFinished will clear the previous backoff and notify the destination queue.
func (s *ServerStatistics) backoffFinished() {
s.ClearBackoff()
// Notify the destinationQueue if one is currently running.
s.notifierMutex.Lock()
defer s.notifierMutex.Unlock()
if s.backoffNotifier != nil {
s.backoffNotifier()
}
}
// BackoffInfo returns information about the current or previous backoff.
@ -174,6 +211,12 @@ func (s *ServerStatistics) Blacklisted() bool {
return s.blacklisted.Load()
}
// RemoveBlacklist removes the blacklisted status from the server.
func (s *ServerStatistics) RemoveBlacklist() {
s.cancel()
s.backoffCount.Store(0)
}
// SuccessCount returns the number of successful requests. This is
// usually useful in constructing transaction IDs.
func (s *ServerStatistics) SuccessCount() uint32 {

View file

@ -7,9 +7,7 @@ import (
)
func TestBackoff(t *testing.T) {
stats := Statistics{
FailuresUntilBlacklist: 7,
}
stats := NewStatistics(nil, 7)
server := ServerStatistics{
statistics: &stats,
serverName: "test.com",
@ -36,7 +34,7 @@ func TestBackoff(t *testing.T) {
// Get the duration.
_, blacklist := server.BackoffInfo()
duration := time.Until(until).Round(time.Second)
duration := time.Until(until)
// Unset the backoff, or otherwise our next call will think that
// there's a backoff in progress and return the same result.
@ -57,8 +55,17 @@ func TestBackoff(t *testing.T) {
// Check if the duration is what we expect.
t.Logf("Backoff %d is for %s", i, duration)
if wanted := time.Second * time.Duration(math.Exp2(float64(i))); !blacklist && duration != wanted {
t.Fatalf("Backoff %d should have been %s but was %s", i, wanted, duration)
roundingAllowance := 0.01
minDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*minJitterMultiplier*1000-roundingAllowance)
maxDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*maxJitterMultiplier*1000+roundingAllowance)
var inJitterRange bool
if duration >= minDuration && duration <= maxDuration {
inJitterRange = true
} else {
inJitterRange = false
}
if !blacklist && !inJitterRange {
t.Fatalf("Backoff %d should have been between %s and %s but was %s", i, minDuration, maxDuration, duration)
}
}
}

View file

@ -18,9 +18,10 @@ import (
"context"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
@ -38,8 +39,8 @@ type Database interface {
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error)
AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error
AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error

View file

@ -52,6 +52,10 @@ type Receipt struct {
nid int64
}
func NewReceipt(nid int64) Receipt {
return Receipt{nid: nid}
}
func (r *Receipt) String() string {
return fmt.Sprintf("%d", r.nid)
}

View file

@ -38,9 +38,9 @@ var defaultExpireEDUTypes = map[string]time.Duration{
// AssociateEDUWithDestination creates an association that the
// destination queues will use to determine which JSON blobs to send
// to which servers.
func (d *Database) AssociateEDUWithDestination(
func (d *Database) AssociateEDUWithDestinations(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
destinations map[gomatrixserverlib.ServerName]struct{},
receipt *Receipt,
eduType string,
expireEDUTypes map[string]time.Duration,
@ -59,17 +59,18 @@ func (d *Database) AssociateEDUWithDestination(
expiresAt = 0
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationQueueEDUs.InsertQueueEDU(
ctx, // context
txn, // SQL transaction
eduType, // EDU type for coalescing
serverName, // destination server name
receipt.nid, // NID from the federationapi_queue_json table
expiresAt, // The timestamp this EDU will expire
); err != nil {
return fmt.Errorf("InsertQueueEDU: %w", err)
var err error
for destination := range destinations {
err = d.FederationQueueEDUs.InsertQueueEDU(
ctx, // context
txn, // SQL transaction
eduType, // EDU type for coalescing
destination, // destination server name
receipt.nid, // NID from the federationapi_queue_json table
expiresAt, // The timestamp this EDU will expire
)
}
return nil
return err
})
}

View file

@ -27,23 +27,23 @@ import (
// AssociatePDUWithDestination creates an association that the
// destination queues will use to determine which JSON blobs to send
// to which servers.
func (d *Database) AssociatePDUWithDestination(
func (d *Database) AssociatePDUWithDestinations(
ctx context.Context,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
destinations map[gomatrixserverlib.ServerName]struct{},
receipt *Receipt,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationQueuePDUs.InsertQueuePDU(
ctx, // context
txn, // SQL transaction
transactionID, // transaction ID
serverName, // destination server name
receipt.nid, // NID from the federationapi_queue_json table
); err != nil {
return fmt.Errorf("InsertQueuePDU: %w", err)
var err error
for destination := range destinations {
err = d.FederationQueuePDUs.InsertQueuePDU(
ctx, // context
txn, // SQL transaction
"", // transaction ID
destination, // destination server name
receipt.nid, // NID from the federationapi_queue_json table
)
}
return nil
return err
})
}

View file

@ -35,6 +35,7 @@ func TestExpireEDUs(t *testing.T) {
}
ctx := context.Background()
destinations := map[gomatrixserverlib.ServerName]struct{}{"localhost": {}}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateFederationDatabase(t, dbType)
defer close()
@ -43,7 +44,7 @@ func TestExpireEDUs(t *testing.T) {
receipt, err := db.StoreJSON(ctx, "{}")
assert.NoError(t, err)
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MReceipt, expireEDUTypes)
err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MReceipt, expireEDUTypes)
assert.NoError(t, err)
}
// add data without expiry
@ -51,7 +52,7 @@ func TestExpireEDUs(t *testing.T) {
assert.NoError(t, err)
// m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, "m.read_marker", expireEDUTypes)
err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, "m.read_marker", expireEDUTypes)
assert.NoError(t, err)
// Delete expired EDUs
@ -67,7 +68,7 @@ func TestExpireEDUs(t *testing.T) {
receipt, err = db.StoreJSON(ctx, "{}")
assert.NoError(t, err)
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes)
err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes)
assert.NoError(t, err)
err = db.DeleteExpiredEDUs(ctx)

4
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a
github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a
github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.15
@ -50,6 +50,7 @@ require (
golang.org/x/term v0.0.0-20220919170432-7a66f970e087
gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0
gotest.tools/v3 v3.4.0
nhooyr.io/websocket v1.8.7
)
@ -127,7 +128,6 @@ require (
gopkg.in/macaroon.v2 v2.1.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gotest.tools/v3 v3.4.0 // indirect
)
go 1.18

4
go.sum
View file

@ -387,8 +387,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a h1:bQKHk3AWlgm7XhzPhuU3Iw3pUptW5l1DR/1y0o7zCKQ=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a h1:6rJFN5NBuzZ7h5meYkLtXKa6VFZfDc8oVXHd4SDXr5o=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 h1:lzkSQvBv8TuqKJCPoVwOVvEnARTlua5rrNy/Qw2Vxeo=
github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=

View file

@ -17,7 +17,7 @@ var build string
const (
VersionMajor = 0
VersionMinor = 10
VersionPatch = 3
VersionPatch = 4
VersionTag = "" // example: "rc1"
)

View file

@ -250,15 +250,13 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap
// nolint:gocyclo
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
var respMu sync.Mutex
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.Failures = make(map[string]interface{})
// get cross-signing keys from the database
a.crossSigningKeysFromDatabase(ctx, req, res)
// make a map from domain to device keys
domainToDeviceKeys := make(map[string]map[string][]string)
domainToCrossSigningKeys := make(map[string]map[string]struct{})
@ -329,12 +327,16 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys)
if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
// perform key queries for remote devices
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
}
// Now that we've done the potentially expensive work of asking the federation,
// try filling the cross-signing keys from the database that we know about.
a.crossSigningKeysFromDatabase(ctx, req, res)
// Finally, append signatures that we know about
// TODO: This is horrible because we need to round-trip the signature from
// JSON, add the signatures and marshal it again, for some reason?
@ -407,7 +409,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
func (a *KeyInternalAPI) remoteKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
) map[string]map[string][]string {
fetchRemote := make(map[string]map[string][]string)
for domain, userToDeviceMap := range domainToDeviceKeys {
@ -415,7 +417,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
// we can't safely return keys from the db when all devices are requested as we don't
// know if one has just been added.
if len(deviceIDs) > 0 {
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs)
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs)
if err == nil {
continue
}
@ -471,7 +473,9 @@ func (a *KeyInternalAPI) queryRemoteKeys(
close(resultCh)
}()
for result := range resultCh {
processResult := func(result *gomatrixserverlib.RespQueryKeys) {
respMu.Lock()
defer respMu.Unlock()
for userID, nest := range result.DeviceKeys {
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
for deviceID, deviceKey := range nest {
@ -494,6 +498,10 @@ func (a *KeyInternalAPI) queryRemoteKeys(
// TODO: do we want to persist these somewhere now
// that we have fetched them?
}
for result := range resultCh {
processResult(result)
}
}
func (a *KeyInternalAPI) queryRemoteKeysOnServer(
@ -541,9 +549,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
}
// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
respMu.Lock()
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil)
respMu.Unlock()
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
if err != nil {
logrus.WithFields(logrus.Fields{
logrus.ErrorKey: err,
@ -567,25 +573,26 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
res.Failures[serverName] = map[string]interface{}{
"message": err.Error(),
}
respMu.Unlock()
// last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server
// is down, better to return something than nothing at all. Clients can know about the failure by
// inspecting the failures map though so they can know it's a cached response.
for userID, dkeys := range devKeys {
// drop the error as it's already a failure at this point
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys)
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys)
}
// Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
respMu.Lock()
if len(res.DeviceKeys) > 0 {
delete(res.Failures, serverName)
}
respMu.Unlock()
}
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
) error {
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
// if we can't query the db or there are fewer keys than requested, fetch from remote.
@ -598,9 +605,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
if len(deviceIDs) == 0 && len(keys) == 0 {
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
}
respMu.Lock()
if res.DeviceKeys[userID] == nil {
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
}
respMu.Unlock()
for _, key := range keys {
if len(key.KeyJSON) == 0 {
@ -610,7 +619,9 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"`
}{key.DisplayName})
respMu.Lock()
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
respMu.Unlock()
}
return nil
}

View file

@ -42,7 +42,7 @@ CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_s
const selectCrossSigningSigsForTargetSQL = "" +
"SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
" WHERE (origin_user_id = $1 OR origin_user_id = target_user_id) AND target_user_id = $2 AND target_key_id = $3"
" WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3"
const upsertCrossSigningSigsForTargetSQL = "" +
"INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +

View file

@ -42,7 +42,7 @@ CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_s
const selectCrossSigningSigsForTargetSQL = "" +
"SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
" WHERE (origin_user_id = $1 OR origin_user_id = target_user_id) AND target_user_id = $2 AND target_key_id = $3"
" WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $3 AND target_key_id = $4"
const upsertCrossSigningSigsForTargetSQL = "" +
"INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
@ -85,7 +85,7 @@ func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error)
func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) (r types.CrossSigningSigMap, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID)
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetUserID, targetKeyID)
if err != nil {
return nil, err
}

View file

@ -187,7 +187,7 @@ type ServerNotices struct {
// The displayname to be used when sending notices
DisplayName string `yaml:"display_name"`
// The avatar of this user
AvatarURL string `yaml:"avatar"`
AvatarURL string `yaml:"avatar_url"`
// The roomname to be used when creating messages
RoomName string `yaml:"room_name"`
}

View file

@ -428,6 +428,13 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
return
}
// Only notify clients about retired invite events, if the user didn't accept the invite.
// The PDU stream will also receive an event about accepting the invitation, so there should
// be a "smooth" transition from invite -> join, and not invite -> leave -> join
if msg.Membership == gomatrixserverlib.Join {
return
}
// Notify any active sync requests that the invite has been retired.
s.inviteStream.Advance(pduPos)
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID)

View file

@ -28,8 +28,9 @@ import (
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const outputRoomEventsSchema = `
@ -133,7 +134,7 @@ const updateEventJSONSQL = "" +
"UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2"
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
const selectStateInRangeSQL = "" +
const selectStateInRangeFilteredSQL = "" +
"SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" +
" FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
@ -146,6 +147,15 @@ const selectStateInRangeSQL = "" +
" ORDER BY id ASC" +
" LIMIT $9"
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
const selectStateInRangeSQL = "" +
"SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" +
" FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
" AND room_id = ANY($3)" +
" ORDER BY id ASC" +
" LIMIT $4"
const deleteEventsForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1"
@ -171,20 +181,21 @@ const selectContextAfterEventSQL = "" +
const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type = ANY($2) ORDER BY id ASC LIMIT $3"
type outputRoomEventsStatements struct {
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
selectEventsWitFilterStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt
selectRecentEventsForSyncStmt *sql.Stmt
selectEarlyEventsStmt *sql.Stmt
selectStateInRangeStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt
selectContextEventStmt *sql.Stmt
selectContextBeforeEventStmt *sql.Stmt
selectContextAfterEventStmt *sql.Stmt
selectSearchStmt *sql.Stmt
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
selectEventsWitFilterStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt
selectRecentEventsForSyncStmt *sql.Stmt
selectEarlyEventsStmt *sql.Stmt
selectStateInRangeFilteredStmt *sql.Stmt
selectStateInRangeStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt
selectContextEventStmt *sql.Stmt
selectContextBeforeEventStmt *sql.Stmt
selectContextAfterEventStmt *sql.Stmt
selectSearchStmt *sql.Stmt
}
func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
@ -214,6 +225,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
{&s.selectRecentEventsStmt, selectRecentEventsSQL},
{&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL},
{&s.selectEarlyEventsStmt, selectEarlyEventsSQL},
{&s.selectStateInRangeFilteredStmt, selectStateInRangeFilteredSQL},
{&s.selectStateInRangeStmt, selectStateInRangeSQL},
{&s.updateEventJSONStmt, updateEventJSONSQL},
{&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL},
@ -240,17 +252,28 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
ctx context.Context, txn *sql.Tx, r types.Range,
stateFilter *gomatrixserverlib.StateFilter, roomIDs []string,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err := stmt.QueryContext(
ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
stateFilter.ContainsURL,
stateFilter.Limit,
)
var rows *sql.Rows
var err error
if stateFilter != nil {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeFilteredStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err = stmt.QueryContext(
ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
stateFilter.ContainsURL,
stateFilter.Limit,
)
} else {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)
rows, err = stmt.QueryContext(
ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
r.High()-r.Low(),
)
}
if err != nil {
return nil, nil, err
}

View file

@ -5,10 +5,11 @@ import (
"database/sql"
"fmt"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
type DatabaseTransaction struct {
@ -277,6 +278,7 @@ func (d *DatabaseTransaction) GetBackwardTopologyPos(
// exclusive of oldPos, inclusive of newPos, for the rooms in which
// the user has new membership events.
// A list of joined room IDs is also returned in case the caller needs it.
// nolint:gocyclo
func (d *DatabaseTransaction) GetStateDeltas(
ctx context.Context, device *userapi.Device,
r types.Range, userID string,
@ -311,7 +313,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
}
// get all the state events ever (i.e. for all available rooms) between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, nil, allRoomIDs)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
@ -326,6 +328,22 @@ func (d *DatabaseTransaction) GetStateDeltas(
return nil, nil, err
}
// get all the state events ever (i.e. for all available rooms) between these two positions
stateNeededFiltered, eventMapFiltered, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
stateFiltered, err := d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
// find out which rooms this user is peeking, if any.
// We do this before joins so any peeks get overwritten
peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r)
@ -371,6 +389,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
// If our membership is now join but the previous membership wasn't
// then this is a "join transition", so we'll insert this room.
if prevMembership != membership {
newlyJoinedRooms[roomID] = true
// Get the full room state, as we'll send that down for a newly
// joined room instead of a delta.
var s []types.StreamEvent
@ -383,8 +402,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
// Add the information for this room into the state so that
// it will get added with all of the rest of the joined rooms.
state[roomID] = s
newlyJoinedRooms[roomID] = true
stateFiltered[roomID] = s
}
// We won't add joined rooms into the delta at this point as they
@ -395,7 +413,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
deltas = append(deltas, types.StateDelta{
Membership: membership,
MembershipPos: ev.StreamPosition,
StateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
StateEvents: d.StreamEventsToEvents(device, stateFiltered[roomID]),
RoomID: roomID,
})
break
@ -407,7 +425,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, types.StateDelta{
Membership: gomatrixserverlib.Join,
StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
StateEvents: d.StreamEventsToEvents(device, stateFiltered[joinedRoomID]),
RoomID: joinedRoomID,
NewlyJoined: newlyJoinedRooms[joinedRoomID],
})

View file

@ -29,8 +29,9 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const outputRoomEventsSchema = `
@ -189,21 +190,36 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
for _, roomID := range roomIDs {
inputParams = append(inputParams, roomID)
}
stmt, params, err := prepareWithFilters(
s.db, txn, stmtSQL, inputParams,
stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes,
nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc,
var (
stmt *sql.Stmt
params []any
err error
)
if stateFilter != nil {
stmt, params, err = prepareWithFilters(
s.db, txn, stmtSQL, inputParams,
stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes,
nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc,
)
} else {
stmt, params, err = prepareWithFilters(
s.db, txn, stmtSQL, inputParams,
nil, nil,
nil, nil,
nil, nil, int(r.High()-r.Low()), FilterOrderAsc,
)
}
if err != nil {
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
defer internal.CloseAndLogIfError(ctx, stmt, "selectStateInRange: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, nil, err
}
defer rows.Close() // nolint: errcheck
defer internal.CloseAndLogIfError(ctx, rows, "selectStateInRange: rows.close() failed")
// Fetch all the state change events for all rooms between the two positions then loop each event and:
// - Keep a cache of the event by ID (99% of state change events are for the event itself)
// - For each room ID, build up an array of event IDs which represents cumulative adds/removes
@ -269,6 +285,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID(
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt)
defer internal.CloseAndLogIfError(ctx, stmt, "SelectMaxEventID: stmt.close() failed")
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
@ -323,6 +340,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
return 0, err
}
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
defer internal.CloseAndLogIfError(ctx, insertStmt, "InsertEvent: stmt.close() failed")
_, err = insertStmt.ExecContext(
ctx,
streamPos,
@ -367,6 +385,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
if err != nil {
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
}
defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
@ -415,6 +434,8 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectEarlyEvents: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
@ -456,6 +477,8 @@ func (s *outputRoomEventsStatements) SelectEvents(
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectEvents: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
@ -558,6 +581,10 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
filter.Types, filter.NotTypes,
nil, filter.ContainsURL, filter.Limit, FilterOrderDesc,
)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectContextBeforeEvent: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
@ -596,6 +623,10 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent(
filter.Types, filter.NotTypes,
nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectContextAfterEvent: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {

View file

@ -0,0 +1,198 @@
package tables_test
import (
"context"
"database/sql"
"reflect"
"sort"
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
)
func newMembershipsTable(t *testing.T, dbType test.DBType) (tables.Memberships, *sql.DB, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.Memberships
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresMembershipsTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSqliteMembershipsTable(db)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, db, close
}
func TestMembershipsTable(t *testing.T) {
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
// Create users
var userEvents []*gomatrixserverlib.HeaderedEvent
users := []string{alice.ID}
for _, x := range room.CurrentState() {
if x.StateKeyEquals(alice.ID) {
if _, err := x.Membership(); err == nil {
userEvents = append(userEvents, x)
break
}
}
}
if len(userEvents) == 0 {
t.Fatalf("didn't find creator membership event")
}
for i := 0; i < 10; i++ {
u := test.NewUser(t)
users = append(users, u.ID)
ev := room.CreateAndInsert(t, u, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(u.ID))
userEvents = append(userEvents, ev)
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
table, _, close := newMembershipsTable(t, dbType)
defer close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
for _, ev := range userEvents {
if err := table.UpsertMembership(ctx, nil, ev, types.StreamPosition(ev.Depth()), 1); err != nil {
t.Fatalf("failed to upsert membership: %s", err)
}
}
testUpsert(t, ctx, table, userEvents[0], alice, room)
testMembershipCount(t, ctx, table, room)
testHeroes(t, ctx, table, alice, room, users)
})
}
func testHeroes(t *testing.T, ctx context.Context, table tables.Memberships, user *test.User, room *test.Room, users []string) {
// Re-slice and sort the expected users
users = users[1:]
sort.Strings(users)
type testCase struct {
name string
memberships []string
wantHeroes []string
}
testCases := []testCase{
{name: "no memberships queried", memberships: []string{}},
{name: "joined memberships queried should be limited", memberships: []string{gomatrixserverlib.Join}, wantHeroes: users[:5]},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := table.SelectHeroes(ctx, nil, room.ID, user.ID, tc.memberships)
if err != nil {
t.Fatalf("unable to select heroes: %s", err)
}
if gotLen := len(got); gotLen != len(tc.wantHeroes) {
t.Fatalf("expected %d heroes, got %d", len(tc.wantHeroes), gotLen)
}
if !reflect.DeepEqual(got, tc.wantHeroes) {
t.Fatalf("expected heroes to be %+v, got %+v", tc.wantHeroes, got)
}
})
}
}
func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) {
t.Run("membership counts are correct", func(t *testing.T) {
// After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users)
count, err := table.SelectMembershipCount(ctx, nil, room.ID, gomatrixserverlib.Join, 10)
if err != nil {
t.Fatalf("failed to get membership count: %s", err)
}
expectedCount := 6
if expectedCount != count {
t.Fatalf("expected member count to be %d, got %d", expectedCount, count)
}
// After 100 events, we should have all 11 users
count, err = table.SelectMembershipCount(ctx, nil, room.ID, gomatrixserverlib.Join, 100)
if err != nil {
t.Fatalf("failed to get membership count: %s", err)
}
expectedCount = 11
if expectedCount != count {
t.Fatalf("expected member count to be %d, got %d", expectedCount, count)
}
})
}
func testUpsert(t *testing.T, ctx context.Context, table tables.Memberships, membershipEvent *gomatrixserverlib.HeaderedEvent, user *test.User, room *test.Room) {
t.Run("upserting works as expected", func(t *testing.T) {
if err := table.UpsertMembership(ctx, nil, membershipEvent, 1, 1); err != nil {
t.Fatalf("failed to upsert membership: %s", err)
}
membership, pos, err := table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 1)
if err != nil {
t.Fatalf("failed to select membership: %s", err)
}
expectedPos := 1
if pos != expectedPos {
t.Fatalf("expected pos to be %d, got %d", expectedPos, pos)
}
if membership != gomatrixserverlib.Join {
t.Fatalf("expected membership to be join, got %s", membership)
}
// Create a new event which gets upserted and should not cause issues
ev := room.CreateAndInsert(t, user, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": gomatrixserverlib.Join,
}, test.WithStateKey(user.ID))
// Insert the same event again, but with different positions, which should get updated
if err = table.UpsertMembership(ctx, nil, ev, 2, 2); err != nil {
t.Fatalf("failed to upsert membership: %s", err)
}
// Verify the position got updated
membership, pos, err = table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 10)
if err != nil {
t.Fatalf("failed to select membership: %s", err)
}
expectedPos = 2
if pos != expectedPos {
t.Fatalf("expected pos to be %d, got %d", expectedPos, pos)
}
if membership != gomatrixserverlib.Join {
t.Fatalf("expected membership to be join, got %s", membership)
}
// If we can't find a membership, it should default to leave
if membership, _, err = table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 1); err != nil {
t.Fatalf("failed to select membership: %s", err)
}
if membership != gomatrixserverlib.Leave {
t.Fatalf("expected membership to be leave, got %s", membership)
}
})
}

View file

@ -74,21 +74,26 @@ func (p *InviteStreamProvider) IncrementalSync(
return to
}
for roomID := range retiredInvites {
if _, ok := req.Response.Rooms.Join[roomID]; !ok {
lr := types.NewLeaveResponse()
h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...))
lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{
// fake event ID which muxes in the to position
EventID: "$" + base64.RawURLEncoding.EncodeToString(h[:]),
OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()),
RoomID: roomID,
Sender: req.Device.UserID,
StateKey: &req.Device.UserID,
Type: "m.room.member",
Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`),
})
req.Response.Rooms.Leave[roomID] = lr
if _, ok := req.Response.Rooms.Invite[roomID]; ok {
continue
}
if _, ok := req.Response.Rooms.Join[roomID]; ok {
continue
}
lr := types.NewLeaveResponse()
h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...))
lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{
// fake event ID which muxes in the to position
EventID: "$" + base64.RawURLEncoding.EncodeToString(h[:]),
OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()),
RoomID: roomID,
Sender: req.Device.UserID,
StateKey: &req.Device.UserID,
Type: "m.room.member",
Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`),
})
req.Response.Rooms.Leave[roomID] = lr
}
return maxID

View file

@ -194,7 +194,7 @@ func (p *PDUStreamProvider) IncrementalSync(
}
}
var pos types.StreamPosition
if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil {
if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req); err != nil {
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
if err == context.DeadlineExceeded || err == context.Canceled || err == sql.ErrTxDone {
return newPos
@ -225,7 +225,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
delta types.StateDelta,
eventFilter *gomatrixserverlib.RoomEventFilter,
stateFilter *gomatrixserverlib.StateFilter,
res *types.Response,
req *types.SyncRequest,
) (types.StreamPosition, error) {
if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave {
// make sure we don't leak recent events after the leave event.
@ -290,8 +290,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
hasMembershipChange := false
for _, recentEvent := range recentStreamEvents {
if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil {
if membership, _ := recentEvent.Membership(); membership == gomatrixserverlib.Join {
req.MembershipChanges[*recentEvent.StateKey()] = struct{}{}
}
hasMembershipChange = true
break
}
}
@ -318,9 +320,9 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Join[delta.RoomID] = jr
req.Response.Rooms.Join[delta.RoomID] = jr
case gomatrixserverlib.Peek:
jr := types.NewJoinResponse()
@ -329,7 +331,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Peek[delta.RoomID] = jr
req.Response.Rooms.Peek[delta.RoomID] = jr
case gomatrixserverlib.Leave:
fallthrough // transitions to leave are the same as ban
@ -342,7 +344,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// didn't "remove" events, return that the response is limited.
lr.Timeline.Limited = limited && len(events) == len(recentEvents)
lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Leave[delta.RoomID] = lr
req.Response.Rooms.Leave[delta.RoomID] = lr
}
return latestPosition, nil

View file

@ -121,7 +121,8 @@ func (p *PresenceStreamProvider) IncrementalSync(
prevPresence := pres.(*types.PresenceInternal)
currentlyActive := prevPresence.CurrentlyActive()
skip := prevPresence.Equals(presence) && currentlyActive && req.Device.UserID != presence.UserID
if skip {
_, membershipChange := req.MembershipChanges[presence.UserID]
if skip && !membershipChange {
req.Log.Tracef("Skipping presence, no change (%s)", presence.UserID)
continue
}

View file

@ -91,15 +91,16 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
})
return &types.SyncRequest{
Context: req.Context(), //
Log: logger, //
Device: &device, //
Response: types.NewResponse(), // Populated by all streams
Filter: filter, //
Since: since, //
Timeout: timeout, //
Rooms: make(map[string]string), // Populated by the PDU stream
WantFullState: wantFullState, //
Context: req.Context(), //
Log: logger, //
Device: &device, //
Response: types.NewResponse(), // Populated by all streams
Filter: filter, //
Since: since, //
Timeout: timeout, //
Rooms: make(map[string]string), // Populated by the PDU stream
WantFullState: wantFullState, //
MembershipChanges: make(map[string]struct{}), // Populated by the PDU stream
}, nil
}

View file

@ -4,9 +4,10 @@ import (
"context"
"time"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
type SyncRequest struct {
@ -22,6 +23,8 @@ type SyncRequest struct {
// Updated by the PDU stream.
Rooms map[string]string
// Updated by the PDU stream.
MembershipChanges map[string]struct{}
// Updated by the PDU stream.
IgnoredUsers IgnoredUsers
}

View file

@ -22,10 +22,6 @@ Forgotten room messages cannot be paginated
Local device key changes get to remote servers with correct prev_id
# Flakey
Local device key changes appear in /keys/changes
# we don't support groups
Remove group category
@ -39,12 +35,6 @@ Events in rooms with AS-hosted room aliases are sent to AS server
Inviting an AS-hosted user asks the AS server
Accesing an AS-hosted room alias asks the AS server
# Flakey, need additional investigation
Messages that notify from another user increment notification_count
Messages that highlight from another user increment unread highlight count
Notifications can be viewed with GET /notifications
# More flakey
Guest users can join guest_access rooms

View file

@ -746,4 +746,10 @@ Existing members see new member's presence
Inbound federation can return missing events for joined visibility
outliers whose auth_events are in a different room are correctly rejected
Messages that notify from another user increment notification_count
Messages that highlight from another user increment unread highlight count
Messages that highlight from another user increment unread highlight count
Newly joined room has correct timeline in incremental sync
When user joins a room the state is included in the next sync
When user joins a room the state is included in a gapped sync
Messages that notify from another user increment notification_count
Messages that highlight from another user increment unread highlight count
Notifications can be viewed with GET /notifications

View file

@ -96,7 +96,7 @@ type ClientUserAPI interface {
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error
SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error
SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
@ -596,7 +596,10 @@ type PerformSetAvatarURLRequest struct {
ServerName gomatrixserverlib.ServerName
AvatarURL string
}
type PerformSetAvatarURLResponse struct{}
type PerformSetAvatarURLResponse struct {
Profile *authtypes.Profile `json:"profile"`
Changed bool `json:"changed"`
}
type QueryNumericLocalpartResponse struct {
ID int64
@ -625,6 +628,11 @@ type PerformUpdateDisplayNameRequest struct {
DisplayName string
}
type PerformUpdateDisplayNameResponse struct {
Profile *authtypes.Profile `json:"profile"`
Changed bool `json:"changed"`
}
type QueryLocalpartForThreePIDRequest struct {
ThreePID, Medium string
}

View file

@ -168,7 +168,7 @@ func (t *UserInternalAPITrace) QueryAccountAvailability(ctx context.Context, req
return err
}
func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error {
func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error {
err := t.Impl.SetDisplayName(ctx, req, res)
util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res))
return err

View file

@ -81,7 +81,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
readPos := msg.Header.Get(jetstream.EventID)
evType := msg.Header.Get("type")
if readPos == "" || evType != "m.read" {
if readPos == "" || (evType != "m.read" && evType != "m.read.private") {
return true
}

View file

@ -10,19 +10,24 @@ import (
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi/storage"
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{
db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "", 4, 0, 0, "")
if err != nil {
t.Fatalf("failed to create new user db: %v", err)
}
return db, close
return db, func() {
close()
baseclose()
}
}
func mustCreateEvent(t *testing.T, content string) *gomatrixserverlib.HeaderedEvent {

View file

@ -175,7 +175,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil
}
if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err
}
@ -831,7 +831,10 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
}
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
return a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
res.Profile = profile
res.Changed = changed
return err
}
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
@ -865,8 +868,11 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
}
}
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, _ *struct{}) error {
return a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
res.Profile = profile
res.Changed = changed
return err
}
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {

View file

@ -388,7 +388,7 @@ func (h *httpUserInternalAPI) QueryAccountByPassword(
func (h *httpUserInternalAPI) SetDisplayName(
ctx context.Context,
request *api.PerformUpdateDisplayNameRequest,
response *struct{},
response *api.PerformUpdateDisplayNameResponse,
) error {
return httputil.CallInternalRPCAPI(
"SetDisplayName", h.apiURL+PerformSetDisplayNamePath,

View file

@ -29,8 +29,8 @@ import (
type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error)
}
type Account interface {

View file

@ -26,7 +26,7 @@ import (
const accountDataSchema = `
-- Stores data about accounts data.
CREATE TABLE IF NOT EXISTS account_data (
CREATE TABLE IF NOT EXISTS userapi_account_datas (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room)
@ -41,15 +41,15 @@ CREATE TABLE IF NOT EXISTS account_data (
`
const insertAccountDataSQL = `
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
`
const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM account_data WHERE localpart = $1"
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
const selectAccountDataByTypeSQL = "" +
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
type accountDataStatements struct {
insertAccountDataStmt *sql.Stmt

View file

@ -32,7 +32,7 @@ import (
const accountsSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS account_accounts (
CREATE TABLE IF NOT EXISTS userapi_accounts (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- When this account was first created, as a unix timestamp (ms resolution).
@ -51,22 +51,22 @@ CREATE TABLE IF NOT EXISTS account_accounts (
`
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
const deactivateAccountSQL = "" +
"UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1"
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM account_accounts WHERE localpart ~ '^[0-9]{1,}$'"
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'"
type accountsStatements struct {
insertAccountStmt *sql.Stmt

View file

@ -7,7 +7,7 @@ import (
)
func UpIsActive(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;")
_, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
@ -15,7 +15,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
}
func DownIsActive(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN is_deactivated;")
_, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts DROP COLUMN is_deactivated;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}

View file

@ -8,9 +8,9 @@ import (
func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS ip TEXT;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000;
ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS ip TEXT;
ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
@ -19,9 +19,9 @@ ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices DROP COLUMN last_seen_ts;
ALTER TABLE device_devices DROP COLUMN ip;
ALTER TABLE device_devices DROP COLUMN user_agent;`)
ALTER TABLE userapi_devices DROP COLUMN last_seen_ts;
ALTER TABLE userapi_devices DROP COLUMN ip;
ALTER TABLE userapi_devices DROP COLUMN user_agent;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}

View file

@ -9,10 +9,10 @@ import (
func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
// initially set every account to useraccount, change appservice and guest accounts afterwards
// (user = 1, guest = 2, admin = 3, appservice = 4)
_, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1;
UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$';
ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`,
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1;
UPDATE userapi_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE userapi_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$';
ALTER TABLE userapi_accounts ALTER COLUMN account_type DROP DEFAULT;`,
)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
@ -21,7 +21,7 @@ ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`,
}
func DownAddAccountType(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN account_type;")
_, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts DROP COLUMN account_type;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}

View file

@ -0,0 +1,102 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
)
var renameTableMappings = map[string]string{
"account_accounts": "userapi_accounts",
"account_data": "userapi_account_datas",
"device_devices": "userapi_devices",
"account_e2e_room_keys": "userapi_key_backups",
"account_e2e_room_keys_versions": "userapi_key_backup_versions",
"login_tokens": "userapi_login_tokens",
"open_id_tokens": "userapi_openid_tokens",
"account_profiles": "userapi_profiles",
"account_threepid": "userapi_threepids",
}
var renameSequenceMappings = map[string]string{
"device_session_id_seq": "userapi_device_session_id_seq",
"account_e2e_room_keys_versions_seq": "userapi_key_backup_versions_seq",
}
var renameIndicesMappings = map[string]string{
"device_localpart_id_idx": "userapi_device_localpart_id_idx",
"e2e_room_keys_idx": "userapi_key_backups_idx",
"e2e_room_keys_versions_idx": "userapi_key_backups_versions_idx",
"account_e2e_room_keys_versions_idx": "userapi_key_backup_versions_idx",
"login_tokens_expiration_idx": "userapi_login_tokens_expiration_idx",
"account_threepid_localpart": "userapi_threepid_idx",
}
// I know what you're thinking: you're wondering "why doesn't this use $1
// and pass variadic parameters to ExecContext?" — the answer is because
// PostgreSQL doesn't expect the table name to be specified as a substituted
// argument in that way so it results in a syntax error in the query.
func UpRenameTables(ctx context.Context, tx *sql.Tx) error {
for old, new := range renameTableMappings {
q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(old), pq.QuoteIdentifier(new),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
}
}
for old, new := range renameSequenceMappings {
q := fmt.Sprintf(
"ALTER SEQUENCE IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(old), pq.QuoteIdentifier(new),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
}
}
for old, new := range renameIndicesMappings {
q := fmt.Sprintf(
"ALTER INDEX IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(old), pq.QuoteIdentifier(new),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
}
}
return nil
}
func DownRenameTables(ctx context.Context, tx *sql.Tx) error {
for old, new := range renameTableMappings {
q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(new), pq.QuoteIdentifier(old),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
}
}
for old, new := range renameSequenceMappings {
q := fmt.Sprintf(
"ALTER SEQUENCE IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(new), pq.QuoteIdentifier(old),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
}
}
for old, new := range renameIndicesMappings {
q := fmt.Sprintf(
"ALTER INDEX IF EXISTS %s RENAME TO %s;",
pq.QuoteIdentifier(new), pq.QuoteIdentifier(old),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
}
}
return nil
}

View file

@ -31,10 +31,10 @@ import (
const devicesSchema = `
-- This sequence is used for automatic allocation of session_id.
CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
CREATE SEQUENCE IF NOT EXISTS userapi_device_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices (
CREATE TABLE IF NOT EXISTS userapi_devices (
-- The access token granted to this device. This has to be the primary key
-- so we can distinguish which device is making a given request.
access_token TEXT NOT NULL PRIMARY KEY,
@ -42,7 +42,7 @@ CREATE TABLE IF NOT EXISTS device_devices (
-- This can be used as a secure substitution of the access token in situations
-- where data is associated with access tokens (e.g. transaction storage),
-- so we don't have to store users' access tokens everywhere.
session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'),
session_id BIGINT NOT NULL DEFAULT nextval('userapi_device_session_id_seq'),
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
-- access_tokens will be clobbered based on the device ID for a user.
device_id TEXT NOT NULL,
@ -65,39 +65,39 @@ CREATE TABLE IF NOT EXISTS device_devices (
);
-- Device IDs must be unique for a given user.
CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(localpart, device_id);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id);
`
const insertDeviceSQL = "" +
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
"INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
" RETURNING session_id"
const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2"
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" +
"UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
type devicesStatements struct {
insertDeviceStmt *sql.Stmt

View file

@ -26,7 +26,7 @@ import (
)
const keyBackupTableSchema = `
CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
CREATE TABLE IF NOT EXISTS userapi_key_backups (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
session_id TEXT NOT NULL,
@ -37,31 +37,31 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
is_verified BOOLEAN NOT NULL,
session_data TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backups_idx ON userapi_key_backups(user_id, room_id, session_id, version);
CREATE INDEX IF NOT EXISTS userapi_key_backups_versions_idx ON userapi_key_backups(user_id, version);
`
const insertBackupKeySQL = "" +
"INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
"INSERT INTO userapi_key_backups(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
const updateBackupKeySQL = "" +
"UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
"UPDATE userapi_key_backups SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
"WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8"
const countKeysSQL = "" +
"SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2"
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
const selectKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2"
const selectKeysByRoomIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3"
const selectKeysByRoomIDAndSessionIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
type keyBackupStatements struct {

View file

@ -26,40 +26,40 @@ import (
)
const keyBackupVersionTableSchema = `
CREATE SEQUENCE IF NOT EXISTS account_e2e_room_keys_versions_seq;
CREATE SEQUENCE IF NOT EXISTS userapi_key_backup_versions_seq;
-- the metadata for each generation of encrypted e2e session backups
CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
CREATE TABLE IF NOT EXISTS userapi_key_backup_versions (
user_id TEXT NOT NULL,
-- this means no 2 users will ever have the same version of e2e session backups which strictly
-- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1.
version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'),
version BIGINT DEFAULT nextval('userapi_key_backup_versions_seq'),
algorithm TEXT NOT NULL,
auth_data TEXT NOT NULL,
etag TEXT NOT NULL,
deleted SMALLINT DEFAULT 0 NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backup_versions_idx ON userapi_key_backup_versions(user_id, version);
`
const insertKeyBackupSQL = "" +
"INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
"INSERT INTO userapi_key_backup_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
const updateKeyBackupAuthDataSQL = "" +
"UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
"UPDATE userapi_key_backup_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
const updateKeyBackupETagSQL = "" +
"UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
"UPDATE userapi_key_backup_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
const deleteKeyBackupSQL = "" +
"UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
"UPDATE userapi_key_backup_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
const selectKeyBackupSQL = "" +
"SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2"
"SELECT algorithm, auth_data, etag, deleted FROM userapi_key_backup_versions WHERE user_id = $1 AND version = $2"
const selectLatestVersionSQL = "" +
"SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1"
"SELECT MAX(version) FROM userapi_key_backup_versions WHERE user_id = $1"
type keyBackupVersionStatements struct {
insertKeyBackupStmt *sql.Stmt

View file

@ -26,7 +26,7 @@ import (
)
const loginTokenSchema = `
CREATE TABLE IF NOT EXISTS login_tokens (
CREATE TABLE IF NOT EXISTS userapi_login_tokens (
-- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
-- When the token expires
@ -37,17 +37,17 @@ CREATE TABLE IF NOT EXISTS login_tokens (
);
-- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
CREATE INDEX IF NOT EXISTS userapi_login_tokens_expiration_idx ON userapi_login_tokens(token_expires_at);
`
const insertLoginTokenSQL = "" +
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
"INSERT INTO userapi_login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
const deleteLoginTokenSQL = "" +
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
"DELETE FROM userapi_login_tokens WHERE token = $1 OR token_expires_at <= $2"
const selectLoginTokenSQL = "" +
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
"SELECT user_id FROM userapi_login_tokens WHERE token = $1 AND token_expires_at > $2"
type loginTokenStatements struct {
insertStmt *sql.Stmt
@ -78,7 +78,7 @@ func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx
// deleteByToken removes the named token.
//
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient.
// The userapi_login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())

View file

@ -13,7 +13,7 @@ import (
const openIDTokenSchema = `
-- Stores data about openid tokens issued for accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens (
CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
-- The value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
@ -24,10 +24,10 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
`
const insertOpenIDTokenSQL = "" +
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct {
insertTokenStmt *sql.Stmt

View file

@ -27,7 +27,7 @@ import (
const profilesSchema = `
-- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS account_profiles (
CREATE TABLE IF NOT EXISTS userapi_profiles (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- The display name for this account
@ -38,19 +38,27 @@ CREATE TABLE IF NOT EXISTS account_profiles (
`
const insertProfileSQL = "" +
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" +
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
"UPDATE userapi_profiles AS new" +
" SET avatar_url = $1" +
" FROM userapi_profiles AS old" +
" WHERE new.localpart = $2" +
" RETURNING new.display_name, old.avatar_url <> new.avatar_url"
const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
"UPDATE userapi_profiles AS new" +
" SET display_name = $1" +
" FROM userapi_profiles AS old" +
" WHERE new.localpart = $2" +
" RETURNING new.avatar_url, old.display_name <> new.display_name"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct {
serverNoticesLocalpart string
@ -100,16 +108,28 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
return
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
AvatarURL: avatarURL,
}
var changed bool
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed)
return profile, changed, err
}
func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
DisplayName: displayName,
}
var changed bool
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed)
return profile, changed, err
}
func (s *profilesStatements) SelectProfilesBySearch(

View file

@ -45,7 +45,7 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON usera
const countUsersLastSeenAfterSQL = "" +
"SELECT COUNT(*) FROM (" +
" SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " +
" SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " +
" GROUP BY localpart" +
" ) u"
@ -62,7 +62,7 @@ R30Users counts the number of 30 day retained users, defined as:
const countR30UsersSQL = `
SELECT platform, COUNT(*) FROM (
SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts)
FROM account_accounts users
FROM userapi_accounts users
INNER JOIN
(SELECT
localpart, last_seen_ts,
@ -75,7 +75,7 @@ SELECT platform, COUNT(*) FROM (
ELSE 'unknown'
END
AS platform
FROM device_devices
FROM userapi_devices
) uip
ON users.localpart = uip.localpart
AND users.account_type <> 4
@ -121,7 +121,7 @@ GROUP BY client_type
`
const countUserByAccountTypeSQL = `
SELECT COUNT(*) FROM account_accounts WHERE account_type = ANY($1)
SELECT COUNT(*) FROM userapi_accounts WHERE account_type = ANY($1)
`
// $1 = All non guest AccountType IDs
@ -134,7 +134,7 @@ SELECT user_type, COUNT(*) AS count FROM (
WHEN account_type = $2 AND appservice_id IS NULL THEN 'guest'
WHEN account_type = ANY($1) AND appservice_id IS NOT NULL THEN 'bridged'
END AS user_type
FROM account_accounts
FROM userapi_accounts
WHERE created_ts > $3
) AS t GROUP BY user_type
`
@ -143,14 +143,14 @@ SELECT user_type, COUNT(*) AS count FROM (
const updateUserDailyVisitsSQL = `
INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent)
SELECT u.localpart, u.device_id, $1, MAX(u.user_agent)
FROM device_devices AS u
FROM userapi_devices AS u
LEFT JOIN (
SELECT localpart, device_id, timestamp FROM userapi_daily_visits
WHERE timestamp = $1
) udv
ON u.localpart = udv.localpart AND u.device_id = udv.device_id
INNER JOIN device_devices d ON d.localpart = u.localpart
INNER JOIN account_accounts a ON a.localpart = u.localpart
INNER JOIN userapi_devices d ON d.localpart = u.localpart
INNER JOIN userapi_accounts a ON a.localpart = u.localpart
WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3
AND a.account_type in (1, 3)
GROUP BY u.localpart, u.device_id

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/shared"
// Import the postgres database driver.
@ -36,6 +37,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
return nil, err
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: rename tables",
Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables,
})
if err = m.Up(base.Context()); err != nil {
return nil, err
}
accountDataTable, err := NewPostgresAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)

View file

@ -26,7 +26,7 @@ import (
const threepidSchema = `
-- Stores data about third party identifiers
CREATE TABLE IF NOT EXISTS account_threepid (
CREATE TABLE IF NOT EXISTS userapi_threepids (
-- The third party identifier
threepid TEXT NOT NULL,
-- The 3PID medium
@ -37,20 +37,20 @@ CREATE TABLE IF NOT EXISTS account_threepid (
PRIMARY KEY(threepid, medium)
);
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart);
CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart);
`
const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
const insertThreePIDSQL = "" +
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
const deleteThreePIDSQL = "" +
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
type threepidStatements struct {
selectLocalpartForThreePIDStmt *sql.Stmt

View file

@ -96,20 +96,24 @@ func (d *Database) GetProfileByLocalpart(
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
return err
})
return
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
return err
})
return
}
// SetPassword sets the account password to the given hash.

View file

@ -25,7 +25,7 @@ import (
const accountDataSchema = `
-- Stores data about accounts data.
CREATE TABLE IF NOT EXISTS account_data (
CREATE TABLE IF NOT EXISTS userapi_account_datas (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room)
@ -40,15 +40,15 @@ CREATE TABLE IF NOT EXISTS account_data (
`
const insertAccountDataSQL = `
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
`
const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM account_data WHERE localpart = $1"
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
const selectAccountDataByTypeSQL = "" +
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
type accountDataStatements struct {
db *sql.DB

View file

@ -32,7 +32,7 @@ import (
const accountsSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS account_accounts (
CREATE TABLE IF NOT EXISTS userapi_accounts (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- When this account was first created, as a unix timestamp (ms resolution).
@ -51,22 +51,22 @@ CREATE TABLE IF NOT EXISTS account_accounts (
`
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
const deactivateAccountSQL = "" +
"UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0"
const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0"
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0"
type accountsStatements struct {
db *sql.DB

View file

@ -8,8 +8,8 @@ import (
func UpIsActive(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts (
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL,
password_hash TEXT,
@ -17,13 +17,13 @@ CREATE TABLE account_accounts (
is_deactivated BOOLEAN DEFAULT 0
);
INSERT
INTO account_accounts (
INTO userapi_accounts (
localpart, created_ts, password_hash, appservice_id
) SELECT
localpart, created_ts, password_hash, appservice_id
FROM account_accounts_tmp
FROM userapi_accounts_tmp
;
DROP TABLE account_accounts_tmp;`)
DROP TABLE userapi_accounts_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
@ -32,21 +32,21 @@ DROP TABLE account_accounts_tmp;`)
func DownIsActive(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts (
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL,
password_hash TEXT,
appservice_id TEXT
);
INSERT
INTO account_accounts (
INTO userapi_accounts (
localpart, created_ts, password_hash, appservice_id
) SELECT
localpart, created_ts, password_hash, appservice_id
FROM account_accounts_tmp
FROM userapi_accounts_tmp
;
DROP TABLE account_accounts_tmp;`)
DROP TABLE userapi_accounts_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}

View file

@ -8,8 +8,8 @@ import (
func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices RENAME TO device_devices_tmp;
CREATE TABLE device_devices (
ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp;
CREATE TABLE userapi_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
@ -22,12 +22,12 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
UNIQUE (localpart, device_id)
);
INSERT
INTO device_devices (
INTO userapi_devices (
access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent
) SELECT
access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', ''
FROM device_devices_tmp;
DROP TABLE device_devices_tmp;`)
FROM userapi_devices_tmp;
DROP TABLE userapi_devices_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
@ -36,8 +36,8 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices RENAME TO device_devices_tmp;
CREATE TABLE IF NOT EXISTS device_devices (
ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp;
CREATE TABLE IF NOT EXISTS userapi_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
@ -47,12 +47,12 @@ CREATE TABLE IF NOT EXISTS device_devices (
UNIQUE (localpart, device_id)
);
INSERT
INTO device_devices (
INTO userapi_devices (
access_token, session_id, device_id, localpart, created_ts, display_name
) SELECT
access_token, session_id, device_id, localpart, created_ts, display_name
FROM device_devices_tmp;
DROP TABLE device_devices_tmp;`)
FROM userapi_devices_tmp;
DROP TABLE userapi_devices_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}

View file

@ -9,8 +9,8 @@ import (
func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
// initially set every account to useraccount, change appservice and guest accounts afterwards
// (user = 1, guest = 2, admin = 3, appservice = 4)
_, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts (
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL,
password_hash TEXT,
@ -19,15 +19,15 @@ CREATE TABLE account_accounts (
account_type INTEGER NOT NULL
);
INSERT
INTO account_accounts (
INTO userapi_accounts (
localpart, created_ts, password_hash, appservice_id, account_type
) SELECT
localpart, created_ts, password_hash, appservice_id, 1
FROM account_accounts_tmp
FROM userapi_accounts_tmp
;
UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE account_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*';
DROP TABLE account_accounts_tmp;`)
UPDATE userapi_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE userapi_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*';
DROP TABLE userapi_accounts_tmp;`)
if err != nil {
return fmt.Errorf("failed to add column: %w", err)
}
@ -35,7 +35,7 @@ DROP TABLE account_accounts_tmp;`)
}
func DownAddAccountType(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts DROP COLUMN account_type;`)
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts DROP COLUMN account_type;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}

View file

@ -0,0 +1,109 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"strings"
)
var renameTableMappings = map[string]string{
"account_accounts": "userapi_accounts",
"account_data": "userapi_account_datas",
"device_devices": "userapi_devices",
"account_e2e_room_keys": "userapi_key_backups",
"account_e2e_room_keys_versions": "userapi_key_backup_versions",
"login_tokens": "userapi_login_tokens",
"open_id_tokens": "userapi_openid_tokens",
"account_profiles": "userapi_profiles",
"account_threepid": "userapi_threepids",
}
var renameIndicesMappings = map[string]string{
"device_localpart_id_idx": "userapi_device_localpart_id_idx",
"e2e_room_keys_idx": "userapi_key_backups_idx",
"e2e_room_keys_versions_idx": "userapi_key_backups_versions_idx",
"account_e2e_room_keys_versions_idx": "userapi_key_backup_versions_idx",
"login_tokens_expiration_idx": "userapi_login_tokens_expiration_idx",
"account_threepid_localpart": "userapi_threepid_idx",
}
func UpRenameTables(ctx context.Context, tx *sql.Tx) error {
for old, new := range renameTableMappings {
// SQLite has no "IF EXISTS" so check if the table exists.
var name string
if err := tx.QueryRowContext(
ctx, "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = $1;", old,
).Scan(&name); err != nil {
if err == sql.ErrNoRows {
continue
}
return err
}
q := fmt.Sprintf(
"ALTER TABLE %s RENAME TO %s;", old, new,
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
}
}
for old, new := range renameIndicesMappings {
var query string
if err := tx.QueryRowContext(
ctx, "SELECT sql FROM sqlite_schema WHERE type = 'index' AND name = $1;", old,
).Scan(&query); err != nil {
if err == sql.ErrNoRows {
continue
}
return err
}
query = strings.Replace(query, old, new, 1)
if _, err := tx.ExecContext(ctx, fmt.Sprintf("DROP INDEX %s;", old)); err != nil {
return fmt.Errorf("drop index %q to %q error: %w", old, new, err)
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return fmt.Errorf("recreate index %q to %q error: %w", old, new, err)
}
}
return nil
}
func DownRenameTables(ctx context.Context, tx *sql.Tx) error {
for old, new := range renameTableMappings {
// SQLite has no "IF EXISTS" so check if the table exists.
var name string
if err := tx.QueryRowContext(
ctx, "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = $1;", new,
).Scan(&name); err != nil {
if err == sql.ErrNoRows {
continue
}
return err
}
q := fmt.Sprintf(
"ALTER TABLE %s RENAME TO %s;", new, old,
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
}
}
for old, new := range renameIndicesMappings {
var query string
if err := tx.QueryRowContext(
ctx, "SELECT sql FROM sqlite_schema WHERE type = 'index' AND name = $1;", new,
).Scan(&query); err != nil {
if err == sql.ErrNoRows {
continue
}
return err
}
query = strings.Replace(query, new, old, 1)
if _, err := tx.ExecContext(ctx, fmt.Sprintf("DROP INDEX %s;", new)); err != nil {
return fmt.Errorf("drop index %q to %q error: %w", new, old, err)
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return fmt.Errorf("recreate index %q to %q error: %w", new, old, err)
}
}
return nil
}

View file

@ -35,7 +35,7 @@ const devicesSchema = `
-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices (
CREATE TABLE IF NOT EXISTS userapi_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
@ -51,38 +51,38 @@ CREATE TABLE IF NOT EXISTS device_devices (
`
const insertDeviceSQL = "" +
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
"INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM device_devices"
"SELECT COUNT(access_token) FROM userapi_devices"
const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2"
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" +
"UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
type devicesStatements struct {
db *sql.DB

View file

@ -26,7 +26,7 @@ import (
)
const keyBackupTableSchema = `
CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
CREATE TABLE IF NOT EXISTS userapi_key_backups (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
session_id TEXT NOT NULL,
@ -37,31 +37,31 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
is_verified BOOLEAN NOT NULL,
session_data TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON userapi_key_backups(user_id, room_id, session_id, version);
CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON userapi_key_backups(user_id, version);
`
const insertBackupKeySQL = "" +
"INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
"INSERT INTO userapi_key_backups(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
const updateBackupKeySQL = "" +
"UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
"UPDATE userapi_key_backups SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
"WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8"
const countKeysSQL = "" +
"SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2"
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
const selectKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2"
const selectKeysByRoomIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3"
const selectKeysByRoomIDAndSessionIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
type keyBackupStatements struct {

View file

@ -27,7 +27,7 @@ import (
const keyBackupVersionTableSchema = `
-- the metadata for each generation of encrypted e2e session backups
CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
CREATE TABLE IF NOT EXISTS userapi_key_backup_versions (
user_id TEXT NOT NULL,
-- this means no 2 users will ever have the same version of e2e session backups which strictly
-- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1.
@ -38,26 +38,26 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
deleted INTEGER DEFAULT 0 NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backup_versions_idx ON userapi_key_backup_versions(user_id, version);
`
const insertKeyBackupSQL = "" +
"INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
"INSERT INTO userapi_key_backup_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
const updateKeyBackupAuthDataSQL = "" +
"UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
"UPDATE userapi_key_backup_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
const updateKeyBackupETagSQL = "" +
"UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
"UPDATE userapi_key_backup_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
const deleteKeyBackupSQL = "" +
"UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
"UPDATE userapi_key_backup_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
const selectKeyBackupSQL = "" +
"SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2"
"SELECT algorithm, auth_data, etag, deleted FROM userapi_key_backup_versions WHERE user_id = $1 AND version = $2"
const selectLatestVersionSQL = "" +
"SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1"
"SELECT MAX(version) FROM userapi_key_backup_versions WHERE user_id = $1"
type keyBackupVersionStatements struct {
insertKeyBackupStmt *sql.Stmt

View file

@ -32,7 +32,7 @@ type loginTokenStatements struct {
}
const loginTokenSchema = `
CREATE TABLE IF NOT EXISTS login_tokens (
CREATE TABLE IF NOT EXISTS userapi_login_tokens (
-- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
-- When the token expires
@ -43,17 +43,17 @@ CREATE TABLE IF NOT EXISTS login_tokens (
);
-- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON userapi_login_tokens(token_expires_at);
`
const insertLoginTokenSQL = "" +
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
"INSERT INTO userapi_login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
const deleteLoginTokenSQL = "" +
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
"DELETE FROM userapi_login_tokens WHERE token = $1 OR token_expires_at <= $2"
const selectLoginTokenSQL = "" +
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
"SELECT user_id FROM userapi_login_tokens WHERE token = $1 AND token_expires_at > $2"
func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
s := &loginTokenStatements{}
@ -78,7 +78,7 @@ func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx
// deleteByToken removes the named token.
//
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient.
// The userapi_login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())

View file

@ -13,7 +13,7 @@ import (
const openIDTokenSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens (
CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
-- The value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
@ -24,10 +24,10 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
`
const insertOpenIDTokenSQL = "" +
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct {
db *sql.DB

View file

@ -27,7 +27,7 @@ import (
const profilesSchema = `
-- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS account_profiles (
CREATE TABLE IF NOT EXISTS userapi_profiles (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
-- The display name for this account
@ -38,19 +38,21 @@ CREATE TABLE IF NOT EXISTS account_profiles (
`
const insertProfileSQL = "" +
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" +
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
" RETURNING display_name"
const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
" RETURNING avatar_url"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct {
db *sql.DB
@ -102,18 +104,40 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) {
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
AvatarURL: avatarURL,
}
old, err := s.SelectProfileByLocalpart(ctx, localpart)
if err != nil {
return old, false, err
}
if old.AvatarURL == avatarURL {
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
_, err = stmt.ExecContext(ctx, avatarURL, localpart)
return
err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
return profile, true, err
}
func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) {
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
DisplayName: displayName,
}
old, err := s.SelectProfileByLocalpart(ctx, localpart)
if err != nil {
return old, false, err
}
if old.DisplayName == displayName {
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
_, err = stmt.ExecContext(ctx, displayName, localpart)
return
err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
return profile, true, err
}
func (s *profilesStatements) SelectProfilesBySearch(

View file

@ -46,7 +46,7 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON usera
const countUsersLastSeenAfterSQL = "" +
"SELECT COUNT(*) FROM (" +
" SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " +
" SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " +
" GROUP BY localpart" +
" ) u"
@ -63,7 +63,7 @@ R30Users counts the number of 30 day retained users, defined as:
const countR30UsersSQL = `
SELECT platform, COUNT(*) FROM (
SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts)
FROM account_accounts users
FROM userapi_accounts users
INNER JOIN
(SELECT
localpart, last_seen_ts,
@ -76,7 +76,7 @@ SELECT platform, COUNT(*) FROM (
ELSE 'unknown'
END
AS platform
FROM device_devices
FROM userapi_devices
) uip
ON users.localpart = uip.localpart
AND users.account_type <> 4
@ -126,7 +126,7 @@ GROUP BY client_type
`
const countUserByAccountTypeSQL = `
SELECT COUNT(*) FROM account_accounts WHERE account_type IN ($1)
SELECT COUNT(*) FROM userapi_accounts WHERE account_type IN ($1)
`
// $1 = Guest AccountType
@ -139,7 +139,7 @@ SELECT user_type, COUNT(*) AS count FROM (
WHEN account_type = $4 AND appservice_id IS NULL THEN 'guest'
WHEN account_type IN ($5) AND appservice_id IS NOT NULL THEN 'bridged'
END AS user_type
FROM account_accounts
FROM userapi_accounts
WHERE created_ts > $8
) AS t GROUP BY user_type
`
@ -148,14 +148,14 @@ SELECT user_type, COUNT(*) AS count FROM (
const updateUserDailyVisitsSQL = `
INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent)
SELECT u.localpart, u.device_id, $1, MAX(u.user_agent)
FROM device_devices AS u
FROM userapi_devices AS u
LEFT JOIN (
SELECT localpart, device_id, timestamp FROM userapi_daily_visits
WHERE timestamp = $1
) udv
ON u.localpart = udv.localpart AND u.device_id = udv.device_id
INNER JOIN device_devices d ON d.localpart = u.localpart
INNER JOIN account_accounts a ON a.localpart = u.localpart
INNER JOIN userapi_devices d ON d.localpart = u.localpart
INNER JOIN userapi_accounts a ON a.localpart = u.localpart
WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3
AND a.account_type in (1, 3)
GROUP BY u.localpart, u.device_id

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/shared"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
)
// NewDatabase creates a new accounts and profiles database
@ -34,6 +35,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
return nil, err
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: rename tables",
Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables,
})
if err = m.Up(base.Context()); err != nil {
return nil, err
}
accountDataTable, err := NewSQLiteAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)

View file

@ -27,7 +27,7 @@ import (
const threepidSchema = `
-- Stores data about third party identifiers
CREATE TABLE IF NOT EXISTS account_threepid (
CREATE TABLE IF NOT EXISTS userapi_threepids (
-- The third party identifier
threepid TEXT NOT NULL,
-- The 3PID medium
@ -38,20 +38,20 @@ CREATE TABLE IF NOT EXISTS account_threepid (
PRIMARY KEY(threepid, medium)
);
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart);
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart);
`
const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
const insertThreePIDSQL = "" +
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
const deleteThreePIDSQL = "" +
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
type threepidStatements struct {
db *sql.DB

View file

@ -16,6 +16,7 @@ import (
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables"
@ -29,14 +30,18 @@ var (
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{
db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
if err != nil {
t.Fatalf("NewUserAPIDatabase returned %s", err)
}
return db, close
return db, func() {
close()
baseclose()
}
}
// Tests storing and getting account data
@ -377,15 +382,23 @@ func Test_Profile(t *testing.T) {
// set avatar & displayname
wantProfile.DisplayName = "Alice"
wantProfile.AvatarURL = "mxc://aliceAvatar"
err = db.SetDisplayName(ctx, aliceLocalpart, "Alice")
assert.NoError(t, err, "unable to set displayname")
err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
// verify profile
gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart)
assert.NoError(t, err, "unable to get profile by localpart")
gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice")
assert.Equal(t, wantProfile, gotProfile)
assert.NoError(t, err, "unable to set displayname")
assert.True(t, changed)
wantProfile.AvatarURL = "mxc://aliceAvatar"
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile)
assert.True(t, changed)
// Setting the same avatar again doesn't change anything
wantProfile.AvatarURL = "mxc://aliceAvatar"
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile)
assert.False(t, changed)
// search profiles
searchRes, err := db.SearchProfiles(ctx, "Alice", 2)

View file

@ -84,8 +84,8 @@ type OpenIDTable interface {
type ProfileTable interface {
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error)
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error)
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error)
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
}

View file

@ -106,7 +106,7 @@ func mustUpdateDeviceLastSeen(
timestamp time.Time,
) {
t.Helper()
_, err := db.ExecContext(ctx, "UPDATE device_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart)
_, err := db.ExecContext(ctx, "UPDATE userapi_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart)
if err != nil {
t.Fatalf("unable to update device last seen")
}
@ -119,7 +119,7 @@ func mustUserUpdateRegistered(
localpart string,
timestamp time.Time,
) {
_, err := db.ExecContext(ctx, "UPDATE account_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart)
_, err := db.ExecContext(ctx, "UPDATE userapi_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart)
if err != nil {
t.Fatalf("unable to update device last seen")
}

View file

@ -23,13 +23,15 @@ import (
"time"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/internal"
@ -48,9 +50,9 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap
if opts.loginTokenLifetime == 0 {
opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
}
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
connStr, close := test.PrepareDBConnectionString(t, dbType)
accountDB, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{
accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
if err != nil {
@ -64,9 +66,12 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap
}
return &internal.UserInternalAPI{
DB: accountDB,
Config: cfg,
}, accountDB, close
DB: accountDB,
Config: cfg,
}, accountDB, func() {
close()
baseclose()
}
}
func TestQueryProfile(t *testing.T) {
@ -79,10 +84,10 @@ func TestQueryProfile(t *testing.T) {
if err != nil {
t.Fatalf("failed to make account: %s", err)
}
if err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
t.Fatalf("failed to set avatar url: %s", err)
}
if err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
t.Fatalf("failed to set display name: %s", err)
}