mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-16 18:43:10 -06:00
Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/statedelta
This commit is contained in:
commit
9cc509c3f4
|
|
@ -19,6 +19,9 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
|
|
@ -26,8 +29,6 @@ import (
|
|||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
type redactionContent struct {
|
||||
|
|
@ -51,7 +52,7 @@ func SendRedaction(
|
|||
|
||||
if txnID != nil {
|
||||
// Try to fetch response from transactionsCache
|
||||
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
|
||||
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
|
||||
return *res
|
||||
}
|
||||
}
|
||||
|
|
@ -144,7 +145,7 @@ func SendRedaction(
|
|||
|
||||
// Add response to transactionsCache
|
||||
if txnID != nil {
|
||||
txnCache.AddTransaction(device.AccessToken, *txnID, &res)
|
||||
txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res)
|
||||
}
|
||||
|
||||
return res
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ func SendEvent(
|
|||
|
||||
if txnID != nil {
|
||||
// Try to fetch response from transactionsCache
|
||||
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
|
||||
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
|
||||
return *res
|
||||
}
|
||||
}
|
||||
|
|
@ -206,7 +206,7 @@ func SendEvent(
|
|||
}
|
||||
// Add response to transactionsCache
|
||||
if txnID != nil {
|
||||
txnCache.AddTransaction(device.AccessToken, *txnID, &res)
|
||||
txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res)
|
||||
}
|
||||
|
||||
// Take a note of how long it took to generate the event vs submit
|
||||
|
|
|
|||
|
|
@ -16,12 +16,13 @@ import (
|
|||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/matrix-org/util"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/clientapi/producers"
|
||||
"github.com/matrix-org/dendrite/internal/transactions"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// SendToDevice handles PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}
|
||||
|
|
@ -33,7 +34,7 @@ func SendToDevice(
|
|||
eventType string, txnID *string,
|
||||
) util.JSONResponse {
|
||||
if txnID != nil {
|
||||
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
|
||||
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
|
||||
return *res
|
||||
}
|
||||
}
|
||||
|
|
@ -63,7 +64,7 @@ func SendToDevice(
|
|||
}
|
||||
|
||||
if txnID != nil {
|
||||
txnCache.AddTransaction(device.AccessToken, *txnID, &res)
|
||||
txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res)
|
||||
}
|
||||
|
||||
return res
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/version"
|
||||
"github.com/matrix-org/gomatrix"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/tokens"
|
||||
|
|
@ -29,6 +28,8 @@ import (
|
|||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/version"
|
||||
|
||||
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
|
|
@ -73,7 +74,7 @@ func SendServerNotice(
|
|||
|
||||
if txnID != nil {
|
||||
// Try to fetch response from transactionsCache
|
||||
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
|
||||
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
|
||||
return *res
|
||||
}
|
||||
}
|
||||
|
|
@ -251,7 +252,7 @@ func SendServerNotice(
|
|||
}
|
||||
// Add response to transactionsCache
|
||||
if txnID != nil {
|
||||
txnCache.AddTransaction(device.AccessToken, *txnID, &res)
|
||||
txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res)
|
||||
}
|
||||
|
||||
// Take a note of how long it took to generate the event vs submit
|
||||
|
|
|
|||
|
|
@ -231,9 +231,9 @@ GEM
|
|||
jekyll-seo-tag (~> 2.1)
|
||||
minitest (5.15.0)
|
||||
multipart-post (2.1.1)
|
||||
nokogiri (1.13.6-arm64-darwin)
|
||||
nokogiri (1.13.9-arm64-darwin)
|
||||
racc (~> 1.4)
|
||||
nokogiri (1.13.6-x86_64-linux)
|
||||
nokogiri (1.13.9-x86_64-linux)
|
||||
racc (~> 1.4)
|
||||
octokit (4.22.0)
|
||||
faraday (>= 0.9)
|
||||
|
|
|
|||
|
|
@ -116,17 +116,14 @@ func NewInternalAPI(
|
|||
_ = federationDB.RemoveAllServersFromBlacklist()
|
||||
}
|
||||
|
||||
stats := &statistics.Statistics{
|
||||
DB: federationDB,
|
||||
FailuresUntilBlacklist: cfg.FederationMaxRetries,
|
||||
}
|
||||
stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1)
|
||||
|
||||
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
|
||||
|
||||
queues := queue.NewOutgoingQueues(
|
||||
federationDB, base.ProcessContext,
|
||||
cfg.Matrix.DisableFederation,
|
||||
cfg.Matrix.ServerName, federation, rsAPI, stats,
|
||||
cfg.Matrix.ServerName, federation, rsAPI, &stats,
|
||||
&queue.SigningInfo{
|
||||
KeyID: cfg.Matrix.KeyID,
|
||||
PrivateKey: cfg.Matrix.PrivateKey,
|
||||
|
|
@ -183,5 +180,5 @@ func NewInternalAPI(
|
|||
}
|
||||
time.AfterFunc(time.Minute, cleanExpiredEDUs)
|
||||
|
||||
return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues, keyRing)
|
||||
return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, &stats, caches, queues, keyRing)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ import (
|
|||
|
||||
const (
|
||||
maxPDUsPerTransaction = 50
|
||||
maxEDUsPerTransaction = 50
|
||||
maxEDUsPerTransaction = 100
|
||||
maxPDUsInMemory = 128
|
||||
maxEDUsInMemory = 128
|
||||
queueIdleTimeout = time.Second * 30
|
||||
|
|
@ -64,7 +64,6 @@ type destinationQueue struct {
|
|||
pendingPDUs []*queuedPDU // PDUs waiting to be sent
|
||||
pendingEDUs []*queuedEDU // EDUs waiting to be sent
|
||||
pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs
|
||||
interruptBackoff chan bool // interrupts backoff
|
||||
}
|
||||
|
||||
// Send event adds the event to the pending queue for the destination.
|
||||
|
|
@ -75,6 +74,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
|
|||
logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a database entry that associates the given PDU NID with
|
||||
// this destination queue. We'll then be able to retrieve the PDU
|
||||
// later.
|
||||
|
|
@ -102,12 +102,12 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
|
|||
oq.overflowed.Store(true)
|
||||
}
|
||||
oq.pendingMutex.Unlock()
|
||||
// Wake up the queue if it's asleep.
|
||||
oq.wakeQueueIfNeeded()
|
||||
select {
|
||||
case oq.notify <- struct{}{}:
|
||||
default:
|
||||
|
||||
if !oq.backingOff.Load() {
|
||||
oq.wakeQueueAndNotify()
|
||||
}
|
||||
} else {
|
||||
oq.overflowed.Store(true)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -147,12 +147,37 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
|
|||
oq.overflowed.Store(true)
|
||||
}
|
||||
oq.pendingMutex.Unlock()
|
||||
// Wake up the queue if it's asleep.
|
||||
oq.wakeQueueIfNeeded()
|
||||
select {
|
||||
case oq.notify <- struct{}{}:
|
||||
default:
|
||||
|
||||
if !oq.backingOff.Load() {
|
||||
oq.wakeQueueAndNotify()
|
||||
}
|
||||
} else {
|
||||
oq.overflowed.Store(true)
|
||||
}
|
||||
}
|
||||
|
||||
// handleBackoffNotifier is registered as the backoff notification
|
||||
// callback with Statistics. It will wakeup and notify the queue
|
||||
// if the queue is currently backing off.
|
||||
func (oq *destinationQueue) handleBackoffNotifier() {
|
||||
// Only wake up the queue if it is backing off.
|
||||
// Otherwise there is no pending work for the queue to handle
|
||||
// so waking the queue would be a waste of resources.
|
||||
if oq.backingOff.Load() {
|
||||
oq.wakeQueueAndNotify()
|
||||
}
|
||||
}
|
||||
|
||||
// wakeQueueAndNotify ensures the destination queue is running and notifies it
|
||||
// that there is pending work.
|
||||
func (oq *destinationQueue) wakeQueueAndNotify() {
|
||||
// Wake up the queue if it's asleep.
|
||||
oq.wakeQueueIfNeeded()
|
||||
|
||||
// Notify the queue that there are events ready to send.
|
||||
select {
|
||||
case oq.notify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -161,10 +186,11 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
|
|||
// then we will interrupt the backoff, causing any federation
|
||||
// requests to retry.
|
||||
func (oq *destinationQueue) wakeQueueIfNeeded() {
|
||||
// If we are backing off then interrupt the backoff.
|
||||
// Clear the backingOff flag and update the backoff metrics if it was set.
|
||||
if oq.backingOff.CompareAndSwap(true, false) {
|
||||
oq.interruptBackoff <- true
|
||||
destinationQueueBackingOff.Dec()
|
||||
}
|
||||
|
||||
// If we aren't running then wake up the queue.
|
||||
if !oq.running.Load() {
|
||||
// Start the queue.
|
||||
|
|
@ -196,38 +222,54 @@ func (oq *destinationQueue) getPendingFromDatabase() {
|
|||
gotEDUs[edu.receipt.String()] = struct{}{}
|
||||
}
|
||||
|
||||
overflowed := false
|
||||
if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 {
|
||||
// We have room in memory for some PDUs - let's request no more than that.
|
||||
if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil {
|
||||
if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, maxPDUsInMemory); err == nil {
|
||||
if len(pdus) == maxPDUsInMemory {
|
||||
overflowed = true
|
||||
}
|
||||
for receipt, pdu := range pdus {
|
||||
if _, ok := gotPDUs[receipt.String()]; ok {
|
||||
continue
|
||||
}
|
||||
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu})
|
||||
retrieved = true
|
||||
if len(oq.pendingPDUs) == maxPDUsInMemory {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination)
|
||||
}
|
||||
}
|
||||
|
||||
if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 {
|
||||
// We have room in memory for some EDUs - let's request no more than that.
|
||||
if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil {
|
||||
if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, maxEDUsInMemory); err == nil {
|
||||
if len(edus) == maxEDUsInMemory {
|
||||
overflowed = true
|
||||
}
|
||||
for receipt, edu := range edus {
|
||||
if _, ok := gotEDUs[receipt.String()]; ok {
|
||||
continue
|
||||
}
|
||||
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu})
|
||||
retrieved = true
|
||||
if len(oq.pendingEDUs) == maxEDUsInMemory {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination)
|
||||
}
|
||||
}
|
||||
|
||||
// If we've retrieved all of the events from the database with room to spare
|
||||
// in memory then we'll no longer consider this queue to be overflowed.
|
||||
if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory {
|
||||
if !overflowed {
|
||||
oq.overflowed.Store(false)
|
||||
} else {
|
||||
}
|
||||
// If we've retrieved some events then notify the destination queue goroutine.
|
||||
if retrieved {
|
||||
|
|
@ -238,6 +280,24 @@ func (oq *destinationQueue) getPendingFromDatabase() {
|
|||
}
|
||||
}
|
||||
|
||||
// checkNotificationsOnClose checks for any remaining notifications
|
||||
// and starts a new backgroundSend goroutine if any exist.
|
||||
func (oq *destinationQueue) checkNotificationsOnClose() {
|
||||
// NOTE : If we are stopping the queue due to blacklist then it
|
||||
// doesn't matter if we have been notified of new work since
|
||||
// this queue instance will be deleted anyway.
|
||||
if !oq.statistics.Blacklisted() {
|
||||
select {
|
||||
case <-oq.notify:
|
||||
// We received a new notification in between the
|
||||
// idle timeout firing and stopping the goroutine.
|
||||
// Immediately restart the queue.
|
||||
oq.wakeQueueAndNotify()
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// backgroundSend is the worker goroutine for sending events.
|
||||
func (oq *destinationQueue) backgroundSend() {
|
||||
// Check if a worker is already running, and if it isn't, then
|
||||
|
|
@ -245,10 +305,17 @@ func (oq *destinationQueue) backgroundSend() {
|
|||
if !oq.running.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
|
||||
// Register queue cleanup functions.
|
||||
// NOTE : The ordering here is very intentional.
|
||||
defer oq.checkNotificationsOnClose()
|
||||
defer oq.running.Store(false)
|
||||
|
||||
destinationQueueRunning.Inc()
|
||||
defer destinationQueueRunning.Dec()
|
||||
defer oq.queues.clearQueue(oq)
|
||||
defer oq.running.Store(false)
|
||||
|
||||
idleTimeout := time.NewTimer(queueIdleTimeout)
|
||||
defer idleTimeout.Stop()
|
||||
|
||||
// Mark the queue as overflowed, so we will consult the database
|
||||
// to see if there's anything new to send.
|
||||
|
|
@ -261,59 +328,33 @@ func (oq *destinationQueue) backgroundSend() {
|
|||
oq.getPendingFromDatabase()
|
||||
}
|
||||
|
||||
// Reset the queue idle timeout.
|
||||
if !idleTimeout.Stop() {
|
||||
select {
|
||||
case <-idleTimeout.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
idleTimeout.Reset(queueIdleTimeout)
|
||||
|
||||
// If we have nothing to do then wait either for incoming events, or
|
||||
// until we hit an idle timeout.
|
||||
select {
|
||||
case <-oq.notify:
|
||||
// There's work to do, either because getPendingFromDatabase
|
||||
// told us there is, or because a new event has come in via
|
||||
// sendEvent/sendEDU.
|
||||
case <-time.After(queueIdleTimeout):
|
||||
// told us there is, a new event has come in via sendEvent/sendEDU,
|
||||
// or we are backing off and it is time to retry.
|
||||
case <-idleTimeout.C:
|
||||
// The worker is idle so stop the goroutine. It'll get
|
||||
// restarted automatically the next time we have an event to
|
||||
// send.
|
||||
return
|
||||
case <-oq.process.Context().Done():
|
||||
// The parent process is shutting down, so stop.
|
||||
oq.statistics.ClearBackoff()
|
||||
return
|
||||
}
|
||||
|
||||
// If we are backing off this server then wait for the
|
||||
// backoff duration to complete first, or until explicitly
|
||||
// told to retry.
|
||||
until, blacklisted := oq.statistics.BackoffInfo()
|
||||
if blacklisted {
|
||||
// It's been suggested that we should give up because the backoff
|
||||
// has exceeded a maximum allowable value. Clean up the in-memory
|
||||
// buffers at this point. The PDU clean-up is already on a defer.
|
||||
logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
|
||||
oq.pendingMutex.Lock()
|
||||
for i := range oq.pendingPDUs {
|
||||
oq.pendingPDUs[i] = nil
|
||||
}
|
||||
for i := range oq.pendingEDUs {
|
||||
oq.pendingEDUs[i] = nil
|
||||
}
|
||||
oq.pendingPDUs = nil
|
||||
oq.pendingEDUs = nil
|
||||
oq.pendingMutex.Unlock()
|
||||
return
|
||||
}
|
||||
if until != nil && until.After(time.Now()) {
|
||||
// We haven't backed off yet, so wait for the suggested amount of
|
||||
// time.
|
||||
duration := time.Until(*until)
|
||||
logrus.Debugf("Backing off %q for %s", oq.destination, duration)
|
||||
oq.backingOff.Store(true)
|
||||
destinationQueueBackingOff.Inc()
|
||||
select {
|
||||
case <-time.After(duration):
|
||||
case <-oq.interruptBackoff:
|
||||
}
|
||||
destinationQueueBackingOff.Dec()
|
||||
oq.backingOff.Store(false)
|
||||
}
|
||||
|
||||
// Work out which PDUs/EDUs to include in the next transaction.
|
||||
oq.pendingMutex.RLock()
|
||||
pduCount := len(oq.pendingPDUs)
|
||||
|
|
@ -328,99 +369,52 @@ func (oq *destinationQueue) backgroundSend() {
|
|||
toSendEDUs := oq.pendingEDUs[:eduCount]
|
||||
oq.pendingMutex.RUnlock()
|
||||
|
||||
// If we didn't get anything from the database and there are no
|
||||
// pending EDUs then there's nothing to do - stop here.
|
||||
if pduCount == 0 && eduCount == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// If we have pending PDUs or EDUs then construct a transaction.
|
||||
// Try sending the next transaction and see what happens.
|
||||
transaction, pc, ec, terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
|
||||
terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
|
||||
if terr != nil {
|
||||
// We failed to send the transaction. Mark it as a failure.
|
||||
oq.statistics.Failure()
|
||||
|
||||
} else if transaction {
|
||||
// If we successfully sent the transaction then clear out
|
||||
// the pending events and EDUs, and wipe our transaction ID.
|
||||
oq.statistics.Success()
|
||||
oq.pendingMutex.Lock()
|
||||
for i := range oq.pendingPDUs[:pc] {
|
||||
oq.pendingPDUs[i] = nil
|
||||
_, blacklisted := oq.statistics.Failure()
|
||||
if !blacklisted {
|
||||
// Register the backoff state and exit the goroutine.
|
||||
// It'll get restarted automatically when the backoff
|
||||
// completes.
|
||||
oq.backingOff.Store(true)
|
||||
destinationQueueBackingOff.Inc()
|
||||
return
|
||||
} else {
|
||||
// Immediately trigger the blacklist logic.
|
||||
oq.blacklistDestination()
|
||||
return
|
||||
}
|
||||
for i := range oq.pendingEDUs[:ec] {
|
||||
oq.pendingEDUs[i] = nil
|
||||
}
|
||||
oq.pendingPDUs = oq.pendingPDUs[pc:]
|
||||
oq.pendingEDUs = oq.pendingEDUs[ec:]
|
||||
oq.pendingMutex.Unlock()
|
||||
} else {
|
||||
oq.handleTransactionSuccess(pduCount, eduCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// nextTransaction creates a new transaction from the pending event
|
||||
// queue and sends it. Returns true if a transaction was sent or
|
||||
// false otherwise.
|
||||
// queue and sends it.
|
||||
// Returns an error if the transaction wasn't sent.
|
||||
func (oq *destinationQueue) nextTransaction(
|
||||
pdus []*queuedPDU,
|
||||
edus []*queuedEDU,
|
||||
) (bool, int, int, error) {
|
||||
// If there's no projected transaction ID then generate one. If
|
||||
// the transaction succeeds then we'll set it back to "" so that
|
||||
// we generate a new one next time. If it fails, we'll preserve
|
||||
// it so that we retry with the same transaction ID.
|
||||
oq.transactionIDMutex.Lock()
|
||||
if oq.transactionID == "" {
|
||||
now := gomatrixserverlib.AsTimestamp(time.Now())
|
||||
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
|
||||
}
|
||||
oq.transactionIDMutex.Unlock()
|
||||
|
||||
) error {
|
||||
// Create the transaction.
|
||||
t := gomatrixserverlib.Transaction{
|
||||
PDUs: []json.RawMessage{},
|
||||
EDUs: []gomatrixserverlib.EDU{},
|
||||
}
|
||||
t.Origin = oq.origin
|
||||
t.Destination = oq.destination
|
||||
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
|
||||
t.TransactionID = oq.transactionID
|
||||
|
||||
// If we didn't get anything from the database and there are no
|
||||
// pending EDUs then there's nothing to do - stop here.
|
||||
if len(pdus) == 0 && len(edus) == 0 {
|
||||
return false, 0, 0, nil
|
||||
}
|
||||
|
||||
var pduReceipts []*shared.Receipt
|
||||
var eduReceipts []*shared.Receipt
|
||||
|
||||
// Go through PDUs that we retrieved from the database, if any,
|
||||
// and add them into the transaction.
|
||||
for _, pdu := range pdus {
|
||||
if pdu == nil || pdu.pdu == nil {
|
||||
continue
|
||||
}
|
||||
// Append the JSON of the event, since this is a json.RawMessage type in the
|
||||
// gomatrixserverlib.Transaction struct
|
||||
t.PDUs = append(t.PDUs, pdu.pdu.JSON())
|
||||
pduReceipts = append(pduReceipts, pdu.receipt)
|
||||
}
|
||||
|
||||
// Do the same for pending EDUS in the queue.
|
||||
for _, edu := range edus {
|
||||
if edu == nil || edu.edu == nil {
|
||||
continue
|
||||
}
|
||||
t.EDUs = append(t.EDUs, *edu.edu)
|
||||
eduReceipts = append(eduReceipts, edu.receipt)
|
||||
}
|
||||
|
||||
t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus)
|
||||
logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
|
||||
|
||||
// Try to send the transaction to the destination server.
|
||||
// TODO: we should check for 500-ish fails vs 400-ish here,
|
||||
// since we shouldn't queue things indefinitely in response
|
||||
// to a 400-ish error
|
||||
ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5)
|
||||
defer cancel()
|
||||
_, err := oq.client.SendTransaction(ctx, t)
|
||||
switch err.(type) {
|
||||
switch errResponse := err.(type) {
|
||||
case nil:
|
||||
// Clean up the transaction in the database.
|
||||
if pduReceipts != nil {
|
||||
|
|
@ -439,16 +433,128 @@ func (oq *destinationQueue) nextTransaction(
|
|||
oq.transactionIDMutex.Lock()
|
||||
oq.transactionID = ""
|
||||
oq.transactionIDMutex.Unlock()
|
||||
return true, len(t.PDUs), len(t.EDUs), nil
|
||||
return nil
|
||||
case gomatrix.HTTPError:
|
||||
// Report that we failed to send the transaction and we
|
||||
// will retry again, subject to backoff.
|
||||
return false, 0, 0, err
|
||||
|
||||
// TODO: we should check for 500-ish fails vs 400-ish here,
|
||||
// since we shouldn't queue things indefinitely in response
|
||||
// to a 400-ish error
|
||||
code := errResponse.Code
|
||||
logrus.Debug("Transaction failed with HTTP", code)
|
||||
return err
|
||||
default:
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"destination": oq.destination,
|
||||
logrus.ErrorKey: err,
|
||||
}).Debugf("Failed to send transaction %q", t.TransactionID)
|
||||
return false, 0, 0, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// createTransaction generates a gomatrixserverlib.Transaction from the provided pdus and edus.
|
||||
// It also returns the associated event receipts so they can be cleaned from the database in
|
||||
// the case of a successful transaction.
|
||||
func (oq *destinationQueue) createTransaction(
|
||||
pdus []*queuedPDU,
|
||||
edus []*queuedEDU,
|
||||
) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) {
|
||||
// If there's no projected transaction ID then generate one. If
|
||||
// the transaction succeeds then we'll set it back to "" so that
|
||||
// we generate a new one next time. If it fails, we'll preserve
|
||||
// it so that we retry with the same transaction ID.
|
||||
oq.transactionIDMutex.Lock()
|
||||
if oq.transactionID == "" {
|
||||
now := gomatrixserverlib.AsTimestamp(time.Now())
|
||||
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
|
||||
}
|
||||
oq.transactionIDMutex.Unlock()
|
||||
|
||||
t := gomatrixserverlib.Transaction{
|
||||
PDUs: []json.RawMessage{},
|
||||
EDUs: []gomatrixserverlib.EDU{},
|
||||
}
|
||||
t.Origin = oq.origin
|
||||
t.Destination = oq.destination
|
||||
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
|
||||
t.TransactionID = oq.transactionID
|
||||
|
||||
var pduReceipts []*shared.Receipt
|
||||
var eduReceipts []*shared.Receipt
|
||||
|
||||
// Go through PDUs that we retrieved from the database, if any,
|
||||
// and add them into the transaction.
|
||||
for _, pdu := range pdus {
|
||||
// These should never be nil.
|
||||
if pdu == nil || pdu.pdu == nil {
|
||||
continue
|
||||
}
|
||||
// Append the JSON of the event, since this is a json.RawMessage type in the
|
||||
// gomatrixserverlib.Transaction struct
|
||||
t.PDUs = append(t.PDUs, pdu.pdu.JSON())
|
||||
pduReceipts = append(pduReceipts, pdu.receipt)
|
||||
}
|
||||
|
||||
// Do the same for pending EDUS in the queue.
|
||||
for _, edu := range edus {
|
||||
// These should never be nil.
|
||||
if edu == nil || edu.edu == nil {
|
||||
continue
|
||||
}
|
||||
t.EDUs = append(t.EDUs, *edu.edu)
|
||||
eduReceipts = append(eduReceipts, edu.receipt)
|
||||
}
|
||||
|
||||
return t, pduReceipts, eduReceipts
|
||||
}
|
||||
|
||||
// blacklistDestination removes all pending PDUs and EDUs that have been cached
|
||||
// and deletes this queue.
|
||||
func (oq *destinationQueue) blacklistDestination() {
|
||||
// It's been suggested that we should give up because the backoff
|
||||
// has exceeded a maximum allowable value. Clean up the in-memory
|
||||
// buffers at this point. The PDU clean-up is already on a defer.
|
||||
logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
|
||||
|
||||
oq.pendingMutex.Lock()
|
||||
for i := range oq.pendingPDUs {
|
||||
oq.pendingPDUs[i] = nil
|
||||
}
|
||||
for i := range oq.pendingEDUs {
|
||||
oq.pendingEDUs[i] = nil
|
||||
}
|
||||
oq.pendingPDUs = nil
|
||||
oq.pendingEDUs = nil
|
||||
oq.pendingMutex.Unlock()
|
||||
|
||||
// Delete this queue as no more messages will be sent to this
|
||||
// destination until it is no longer blacklisted.
|
||||
oq.statistics.AssignBackoffNotifier(nil)
|
||||
oq.queues.clearQueue(oq)
|
||||
}
|
||||
|
||||
// handleTransactionSuccess updates the cached event queues as well as the success and
|
||||
// backoff information for this server.
|
||||
func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) {
|
||||
// If we successfully sent the transaction then clear out
|
||||
// the pending events and EDUs, and wipe our transaction ID.
|
||||
oq.statistics.Success()
|
||||
oq.pendingMutex.Lock()
|
||||
for i := range oq.pendingPDUs[:pduCount] {
|
||||
oq.pendingPDUs[i] = nil
|
||||
}
|
||||
for i := range oq.pendingEDUs[:eduCount] {
|
||||
oq.pendingEDUs[i] = nil
|
||||
}
|
||||
oq.pendingPDUs = oq.pendingPDUs[pduCount:]
|
||||
oq.pendingEDUs = oq.pendingEDUs[eduCount:]
|
||||
oq.pendingMutex.Unlock()
|
||||
|
||||
if len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0 {
|
||||
select {
|
||||
case oq.notify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -162,23 +162,25 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d
|
|||
if !ok || oq == nil {
|
||||
destinationQueueTotal.Inc()
|
||||
oq = &destinationQueue{
|
||||
queues: oqs,
|
||||
db: oqs.db,
|
||||
process: oqs.process,
|
||||
rsAPI: oqs.rsAPI,
|
||||
origin: oqs.origin,
|
||||
destination: destination,
|
||||
client: oqs.client,
|
||||
statistics: oqs.statistics.ForServer(destination),
|
||||
notify: make(chan struct{}, 1),
|
||||
interruptBackoff: make(chan bool),
|
||||
signing: oqs.signing,
|
||||
queues: oqs,
|
||||
db: oqs.db,
|
||||
process: oqs.process,
|
||||
rsAPI: oqs.rsAPI,
|
||||
origin: oqs.origin,
|
||||
destination: destination,
|
||||
client: oqs.client,
|
||||
statistics: oqs.statistics.ForServer(destination),
|
||||
notify: make(chan struct{}, 1),
|
||||
signing: oqs.signing,
|
||||
}
|
||||
oq.statistics.AssignBackoffNotifier(oq.handleBackoffNotifier)
|
||||
oqs.queues[destination] = oq
|
||||
}
|
||||
return oq
|
||||
}
|
||||
|
||||
// clearQueue removes the queue for the provided destination from the
|
||||
// set of destination queues.
|
||||
func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) {
|
||||
oqs.queuesMutex.Lock()
|
||||
defer oqs.queuesMutex.Unlock()
|
||||
|
|
@ -332,7 +334,9 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
|
|||
if oqs.disabled {
|
||||
return
|
||||
}
|
||||
oqs.statistics.ForServer(srv).RemoveBlacklist()
|
||||
if queue := oqs.getQueue(srv); queue != nil {
|
||||
queue.statistics.ClearBackoff()
|
||||
queue.wakeQueueIfNeeded()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
1047
federationapi/queue/queue_test.go
Normal file
1047
federationapi/queue/queue_test.go
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -2,6 +2,7 @@ package statistics
|
|||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
|
@ -20,12 +21,23 @@ type Statistics struct {
|
|||
servers map[gomatrixserverlib.ServerName]*ServerStatistics
|
||||
mutex sync.RWMutex
|
||||
|
||||
backoffTimers map[gomatrixserverlib.ServerName]*time.Timer
|
||||
backoffMutex sync.RWMutex
|
||||
|
||||
// How many times should we tolerate consecutive failures before we
|
||||
// just blacklist the host altogether? The backoff is exponential,
|
||||
// so the max time here to attempt is 2**failures seconds.
|
||||
FailuresUntilBlacklist uint32
|
||||
}
|
||||
|
||||
func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics {
|
||||
return Statistics{
|
||||
DB: db,
|
||||
FailuresUntilBlacklist: failuresUntilBlacklist,
|
||||
backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer),
|
||||
}
|
||||
}
|
||||
|
||||
// ForServer returns server statistics for the given server name. If it
|
||||
// does not exist, it will create empty statistics and return those.
|
||||
func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics {
|
||||
|
|
@ -45,7 +57,6 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
|
|||
server = &ServerStatistics{
|
||||
statistics: s,
|
||||
serverName: serverName,
|
||||
interrupt: make(chan struct{}),
|
||||
}
|
||||
s.servers[serverName] = server
|
||||
s.mutex.Unlock()
|
||||
|
|
@ -64,29 +75,43 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
|
|||
// many times we failed etc. It also manages the backoff time and black-
|
||||
// listing a remote host if it remains uncooperative.
|
||||
type ServerStatistics struct {
|
||||
statistics *Statistics //
|
||||
serverName gomatrixserverlib.ServerName //
|
||||
blacklisted atomic.Bool // is the node blacklisted
|
||||
backoffStarted atomic.Bool // is the backoff started
|
||||
backoffUntil atomic.Value // time.Time until this backoff interval ends
|
||||
backoffCount atomic.Uint32 // number of times BackoffDuration has been called
|
||||
interrupt chan struct{} // interrupts the backoff goroutine
|
||||
successCounter atomic.Uint32 // how many times have we succeeded?
|
||||
statistics *Statistics //
|
||||
serverName gomatrixserverlib.ServerName //
|
||||
blacklisted atomic.Bool // is the node blacklisted
|
||||
backoffStarted atomic.Bool // is the backoff started
|
||||
backoffUntil atomic.Value // time.Time until this backoff interval ends
|
||||
backoffCount atomic.Uint32 // number of times BackoffDuration has been called
|
||||
successCounter atomic.Uint32 // how many times have we succeeded?
|
||||
backoffNotifier func() // notifies destination queue when backoff completes
|
||||
notifierMutex sync.Mutex
|
||||
}
|
||||
|
||||
const maxJitterMultiplier = 1.4
|
||||
const minJitterMultiplier = 0.8
|
||||
|
||||
// duration returns how long the next backoff interval should be.
|
||||
func (s *ServerStatistics) duration(count uint32) time.Duration {
|
||||
return time.Second * time.Duration(math.Exp2(float64(count)))
|
||||
// Add some jitter to minimise the chance of having multiple backoffs
|
||||
// ending at the same time.
|
||||
jitter := rand.Float64()*(maxJitterMultiplier-minJitterMultiplier) + minJitterMultiplier
|
||||
duration := time.Millisecond * time.Duration(math.Exp2(float64(count))*jitter*1000)
|
||||
return duration
|
||||
}
|
||||
|
||||
// cancel will interrupt the currently active backoff.
|
||||
func (s *ServerStatistics) cancel() {
|
||||
s.blacklisted.Store(false)
|
||||
s.backoffUntil.Store(time.Time{})
|
||||
select {
|
||||
case s.interrupt <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
s.ClearBackoff()
|
||||
}
|
||||
|
||||
// AssignBackoffNotifier configures the channel to send to when
|
||||
// a backoff completes.
|
||||
func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) {
|
||||
s.notifierMutex.Lock()
|
||||
defer s.notifierMutex.Unlock()
|
||||
s.backoffNotifier = notifier
|
||||
}
|
||||
|
||||
// Success updates the server statistics with a new successful
|
||||
|
|
@ -95,8 +120,8 @@ func (s *ServerStatistics) cancel() {
|
|||
// we will unblacklist it.
|
||||
func (s *ServerStatistics) Success() {
|
||||
s.cancel()
|
||||
s.successCounter.Inc()
|
||||
s.backoffCount.Store(0)
|
||||
s.successCounter.Inc()
|
||||
if s.statistics.DB != nil {
|
||||
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
|
||||
|
|
@ -105,13 +130,17 @@ func (s *ServerStatistics) Success() {
|
|||
}
|
||||
|
||||
// Failure marks a failure and starts backing off if needed.
|
||||
// The next call to BackoffIfRequired will do the right thing
|
||||
// after this. It will return the time that the current failure
|
||||
// It will return the time that the current failure
|
||||
// will result in backoff waiting until, and a bool signalling
|
||||
// whether we have blacklisted and therefore to give up.
|
||||
func (s *ServerStatistics) Failure() (time.Time, bool) {
|
||||
// Return immediately if we have blacklisted this node.
|
||||
if s.blacklisted.Load() {
|
||||
return time.Time{}, true
|
||||
}
|
||||
|
||||
// If we aren't already backing off, this call will start
|
||||
// a new backoff period. Increase the failure counter and
|
||||
// a new backoff period, increase the failure counter and
|
||||
// start a goroutine which will wait out the backoff and
|
||||
// unset the backoffStarted flag when done.
|
||||
if s.backoffStarted.CompareAndSwap(false, true) {
|
||||
|
|
@ -122,40 +151,48 @@ func (s *ServerStatistics) Failure() (time.Time, bool) {
|
|||
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
|
||||
}
|
||||
}
|
||||
s.ClearBackoff()
|
||||
return time.Time{}, true
|
||||
}
|
||||
|
||||
go func() {
|
||||
until, ok := s.backoffUntil.Load().(time.Time)
|
||||
if ok && !until.IsZero() {
|
||||
select {
|
||||
case <-time.After(time.Until(until)):
|
||||
case <-s.interrupt:
|
||||
}
|
||||
s.backoffStarted.Store(false)
|
||||
}
|
||||
}()
|
||||
// We're starting a new back off so work out what the next interval
|
||||
// will be.
|
||||
count := s.backoffCount.Load()
|
||||
until := time.Now().Add(s.duration(count))
|
||||
s.backoffUntil.Store(until)
|
||||
|
||||
s.statistics.backoffMutex.Lock()
|
||||
defer s.statistics.backoffMutex.Unlock()
|
||||
s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished)
|
||||
}
|
||||
|
||||
// Check if we have blacklisted this node.
|
||||
if s.blacklisted.Load() {
|
||||
return time.Now(), true
|
||||
}
|
||||
return s.backoffUntil.Load().(time.Time), false
|
||||
}
|
||||
|
||||
// If we're already backing off and we haven't yet surpassed
|
||||
// the deadline then return that. Repeated calls to Failure
|
||||
// within a single backoff interval will have no side effects.
|
||||
if until, ok := s.backoffUntil.Load().(time.Time); ok && !time.Now().After(until) {
|
||||
return until, false
|
||||
// ClearBackoff stops the backoff timer for this destination if it is running
|
||||
// and removes the timer from the backoffTimers map.
|
||||
func (s *ServerStatistics) ClearBackoff() {
|
||||
// If the timer is still running then stop it so it's memory is cleaned up sooner.
|
||||
s.statistics.backoffMutex.Lock()
|
||||
defer s.statistics.backoffMutex.Unlock()
|
||||
if timer, ok := s.statistics.backoffTimers[s.serverName]; ok {
|
||||
timer.Stop()
|
||||
}
|
||||
delete(s.statistics.backoffTimers, s.serverName)
|
||||
|
||||
// We're either backing off and have passed the deadline, or
|
||||
// we aren't backing off, so work out what the next interval
|
||||
// will be.
|
||||
count := s.backoffCount.Load()
|
||||
until := time.Now().Add(s.duration(count))
|
||||
s.backoffUntil.Store(until)
|
||||
return until, false
|
||||
s.backoffStarted.Store(false)
|
||||
}
|
||||
|
||||
// backoffFinished will clear the previous backoff and notify the destination queue.
|
||||
func (s *ServerStatistics) backoffFinished() {
|
||||
s.ClearBackoff()
|
||||
|
||||
// Notify the destinationQueue if one is currently running.
|
||||
s.notifierMutex.Lock()
|
||||
defer s.notifierMutex.Unlock()
|
||||
if s.backoffNotifier != nil {
|
||||
s.backoffNotifier()
|
||||
}
|
||||
}
|
||||
|
||||
// BackoffInfo returns information about the current or previous backoff.
|
||||
|
|
@ -174,6 +211,12 @@ func (s *ServerStatistics) Blacklisted() bool {
|
|||
return s.blacklisted.Load()
|
||||
}
|
||||
|
||||
// RemoveBlacklist removes the blacklisted status from the server.
|
||||
func (s *ServerStatistics) RemoveBlacklist() {
|
||||
s.cancel()
|
||||
s.backoffCount.Store(0)
|
||||
}
|
||||
|
||||
// SuccessCount returns the number of successful requests. This is
|
||||
// usually useful in constructing transaction IDs.
|
||||
func (s *ServerStatistics) SuccessCount() uint32 {
|
||||
|
|
|
|||
|
|
@ -7,9 +7,7 @@ import (
|
|||
)
|
||||
|
||||
func TestBackoff(t *testing.T) {
|
||||
stats := Statistics{
|
||||
FailuresUntilBlacklist: 7,
|
||||
}
|
||||
stats := NewStatistics(nil, 7)
|
||||
server := ServerStatistics{
|
||||
statistics: &stats,
|
||||
serverName: "test.com",
|
||||
|
|
@ -36,7 +34,7 @@ func TestBackoff(t *testing.T) {
|
|||
|
||||
// Get the duration.
|
||||
_, blacklist := server.BackoffInfo()
|
||||
duration := time.Until(until).Round(time.Second)
|
||||
duration := time.Until(until)
|
||||
|
||||
// Unset the backoff, or otherwise our next call will think that
|
||||
// there's a backoff in progress and return the same result.
|
||||
|
|
@ -57,8 +55,17 @@ func TestBackoff(t *testing.T) {
|
|||
|
||||
// Check if the duration is what we expect.
|
||||
t.Logf("Backoff %d is for %s", i, duration)
|
||||
if wanted := time.Second * time.Duration(math.Exp2(float64(i))); !blacklist && duration != wanted {
|
||||
t.Fatalf("Backoff %d should have been %s but was %s", i, wanted, duration)
|
||||
roundingAllowance := 0.01
|
||||
minDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*minJitterMultiplier*1000-roundingAllowance)
|
||||
maxDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*maxJitterMultiplier*1000+roundingAllowance)
|
||||
var inJitterRange bool
|
||||
if duration >= minDuration && duration <= maxDuration {
|
||||
inJitterRange = true
|
||||
} else {
|
||||
inJitterRange = false
|
||||
}
|
||||
if !blacklist && !inJitterRange {
|
||||
t.Fatalf("Backoff %d should have been between %s and %s but was %s", i, minDuration, maxDuration, duration)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -52,6 +52,10 @@ type Receipt struct {
|
|||
nid int64
|
||||
}
|
||||
|
||||
func NewReceipt(nid int64) Receipt {
|
||||
return Receipt{nid: nid}
|
||||
}
|
||||
|
||||
func (r *Receipt) String() string {
|
||||
return fmt.Sprintf("%d", r.nid)
|
||||
}
|
||||
|
|
|
|||
4
go.mod
4
go.mod
|
|
@ -22,7 +22,7 @@ require (
|
|||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a
|
||||
github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||
github.com/mattn/go-sqlite3 v1.14.15
|
||||
|
|
@ -50,6 +50,7 @@ require (
|
|||
golang.org/x/term v0.0.0-20220919170432-7a66f970e087
|
||||
gopkg.in/h2non/bimg.v1 v1.1.9
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gotest.tools/v3 v3.4.0
|
||||
nhooyr.io/websocket v1.8.7
|
||||
)
|
||||
|
||||
|
|
@ -127,7 +128,6 @@ require (
|
|||
gopkg.in/macaroon.v2 v2.1.0 // indirect
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
gotest.tools/v3 v3.4.0 // indirect
|
||||
)
|
||||
|
||||
go 1.18
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -387,8 +387,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
|
|||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241 h1:e5o68MWeU7wjTvvNKmVo655oCYesoNRoPeBb1Xfz54g=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a h1:bQKHk3AWlgm7XhzPhuU3Iw3pUptW5l1DR/1y0o7zCKQ=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
|
||||
github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 h1:lzkSQvBv8TuqKJCPoVwOVvEnARTlua5rrNy/Qw2Vxeo=
|
||||
github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k=
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@
|
|||
package transactions
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
|
@ -29,6 +31,7 @@ type txnsMap map[CacheKey]*util.JSONResponse
|
|||
type CacheKey struct {
|
||||
AccessToken string
|
||||
TxnID string
|
||||
Endpoint string
|
||||
}
|
||||
|
||||
// Cache represents a temporary store for response entries.
|
||||
|
|
@ -57,14 +60,14 @@ func NewWithCleanupPeriod(cleanupPeriod time.Duration) *Cache {
|
|||
return &t
|
||||
}
|
||||
|
||||
// FetchTransaction looks up an entry for the (accessToken, txnID) tuple in Cache.
|
||||
// FetchTransaction looks up an entry for the (accessToken, txnID, req.URL) tuple in Cache.
|
||||
// Looks in both the txnMaps.
|
||||
// Returns (JSON response, true) if txnID is found, else the returned bool is false.
|
||||
func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse, bool) {
|
||||
func (t *Cache) FetchTransaction(accessToken, txnID string, u *url.URL) (*util.JSONResponse, bool) {
|
||||
t.RLock()
|
||||
defer t.RUnlock()
|
||||
for _, txns := range t.txnsMaps {
|
||||
res, ok := txns[CacheKey{accessToken, txnID}]
|
||||
res, ok := txns[CacheKey{accessToken, txnID, filepath.Dir(u.Path)}]
|
||||
if ok {
|
||||
return res, true
|
||||
}
|
||||
|
|
@ -72,13 +75,12 @@ func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse,
|
|||
return nil, false
|
||||
}
|
||||
|
||||
// AddTransaction adds an entry for the (accessToken, txnID) tuple in Cache.
|
||||
// AddTransaction adds an entry for the (accessToken, txnID, req.URL) tuple in Cache.
|
||||
// Adds to the front txnMap.
|
||||
func (t *Cache) AddTransaction(accessToken, txnID string, res *util.JSONResponse) {
|
||||
func (t *Cache) AddTransaction(accessToken, txnID string, u *url.URL, res *util.JSONResponse) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
|
||||
t.txnsMaps[0][CacheKey{accessToken, txnID}] = res
|
||||
t.txnsMaps[0][CacheKey{accessToken, txnID, filepath.Dir(u.Path)}] = res
|
||||
}
|
||||
|
||||
// cacheCleanService is responsible for cleaning up entries after cleanupPeriod.
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@ package transactions
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
|
|
@ -24,6 +27,16 @@ type fakeType struct {
|
|||
ID string `json:"ID"`
|
||||
}
|
||||
|
||||
func TestCompare(t *testing.T) {
|
||||
u1, _ := url.Parse("/send/1?accessToken=123")
|
||||
u2, _ := url.Parse("/send/1")
|
||||
c1 := CacheKey{"1", "2", filepath.Dir(u1.Path)}
|
||||
c2 := CacheKey{"1", "2", filepath.Dir(u2.Path)}
|
||||
if !reflect.DeepEqual(c1, c2) {
|
||||
t.Fatalf("Cache keys differ: %+v <> %+v", c1, c2)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
fakeAccessToken = "aRandomAccessToken"
|
||||
fakeAccessToken2 = "anotherRandomAccessToken"
|
||||
|
|
@ -34,23 +47,28 @@ var (
|
|||
fakeResponse2 = &util.JSONResponse{
|
||||
Code: http.StatusOK, JSON: fakeType{ID: "1"},
|
||||
}
|
||||
fakeResponse3 = &util.JSONResponse{
|
||||
Code: http.StatusOK, JSON: fakeType{ID: "2"},
|
||||
}
|
||||
)
|
||||
|
||||
// TestCache creates a New Cache and tests AddTransaction & FetchTransaction
|
||||
func TestCache(t *testing.T) {
|
||||
fakeTxnCache := New()
|
||||
fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse)
|
||||
u, _ := url.Parse("")
|
||||
fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, u, fakeResponse)
|
||||
|
||||
// Add entries for noise.
|
||||
for i := 1; i <= 100; i++ {
|
||||
fakeTxnCache.AddTransaction(
|
||||
fakeAccessToken,
|
||||
fakeTxnID+strconv.Itoa(i),
|
||||
u,
|
||||
&util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: strconv.Itoa(i)}},
|
||||
)
|
||||
}
|
||||
|
||||
testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID)
|
||||
testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID, u)
|
||||
if !ok {
|
||||
t.Error("Failed to retrieve entry for txnID: ", fakeTxnID)
|
||||
} else if testResponse.JSON != fakeResponse.JSON {
|
||||
|
|
@ -59,20 +77,30 @@ func TestCache(t *testing.T) {
|
|||
}
|
||||
|
||||
// TestCacheScope ensures transactions with the same transaction ID are not shared
|
||||
// across multiple access tokens.
|
||||
// across multiple access tokens and endpoints.
|
||||
func TestCacheScope(t *testing.T) {
|
||||
cache := New()
|
||||
cache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse)
|
||||
cache.AddTransaction(fakeAccessToken2, fakeTxnID, fakeResponse2)
|
||||
sendEndpoint, _ := url.Parse("/send/1?accessToken=test")
|
||||
sendToDeviceEndpoint, _ := url.Parse("/sendToDevice/1")
|
||||
cache.AddTransaction(fakeAccessToken, fakeTxnID, sendEndpoint, fakeResponse)
|
||||
cache.AddTransaction(fakeAccessToken2, fakeTxnID, sendEndpoint, fakeResponse2)
|
||||
cache.AddTransaction(fakeAccessToken2, fakeTxnID, sendToDeviceEndpoint, fakeResponse3)
|
||||
|
||||
if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID); !ok {
|
||||
if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID, sendEndpoint); !ok {
|
||||
t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID)
|
||||
} else if res.JSON != fakeResponse.JSON {
|
||||
t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse.JSON, res.JSON)
|
||||
}
|
||||
if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID); !ok {
|
||||
if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID, sendEndpoint); !ok {
|
||||
t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID)
|
||||
} else if res.JSON != fakeResponse2.JSON {
|
||||
t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON)
|
||||
}
|
||||
|
||||
// Ensure the txnID is not shared across endpoints
|
||||
if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID, sendToDeviceEndpoint); !ok {
|
||||
t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID)
|
||||
} else if res.JSON != fakeResponse3.JSON {
|
||||
t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -250,6 +250,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap
|
|||
|
||||
// nolint:gocyclo
|
||||
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
|
||||
var respMu sync.Mutex
|
||||
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
||||
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
||||
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
||||
|
|
@ -329,7 +330,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
|||
}
|
||||
|
||||
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
|
||||
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
|
||||
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys)
|
||||
if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
|
||||
// perform key queries for remote devices
|
||||
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
|
||||
|
|
@ -407,7 +408,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
|||
}
|
||||
|
||||
func (a *KeyInternalAPI) remoteKeysFromDatabase(
|
||||
ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
|
||||
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
|
||||
) map[string]map[string][]string {
|
||||
fetchRemote := make(map[string]map[string][]string)
|
||||
for domain, userToDeviceMap := range domainToDeviceKeys {
|
||||
|
|
@ -415,7 +416,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
|
|||
// we can't safely return keys from the db when all devices are requested as we don't
|
||||
// know if one has just been added.
|
||||
if len(deviceIDs) > 0 {
|
||||
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs)
|
||||
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
|
@ -541,9 +542,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
|
|||
}
|
||||
// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
|
||||
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
|
||||
respMu.Lock()
|
||||
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil)
|
||||
respMu.Unlock()
|
||||
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
logrus.ErrorKey: err,
|
||||
|
|
@ -567,25 +566,26 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
|
|||
res.Failures[serverName] = map[string]interface{}{
|
||||
"message": err.Error(),
|
||||
}
|
||||
respMu.Unlock()
|
||||
|
||||
// last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server
|
||||
// is down, better to return something than nothing at all. Clients can know about the failure by
|
||||
// inspecting the failures map though so they can know it's a cached response.
|
||||
for userID, dkeys := range devKeys {
|
||||
// drop the error as it's already a failure at this point
|
||||
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys)
|
||||
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys)
|
||||
}
|
||||
|
||||
// Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
|
||||
respMu.Lock()
|
||||
if len(res.DeviceKeys) > 0 {
|
||||
delete(res.Failures, serverName)
|
||||
}
|
||||
respMu.Unlock()
|
||||
|
||||
}
|
||||
|
||||
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
|
||||
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
|
||||
) error {
|
||||
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
||||
|
|
@ -598,9 +598,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
|||
if len(deviceIDs) == 0 && len(keys) == 0 {
|
||||
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
|
||||
}
|
||||
respMu.Lock()
|
||||
if res.DeviceKeys[userID] == nil {
|
||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||
}
|
||||
respMu.Unlock()
|
||||
|
||||
for _, key := range keys {
|
||||
if len(key.KeyJSON) == 0 {
|
||||
|
|
@ -610,7 +612,9 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
|||
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
||||
DisplayName string `json:"device_display_name,omitempty"`
|
||||
}{key.DisplayName})
|
||||
respMu.Lock()
|
||||
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
||||
respMu.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,19 +10,24 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{
|
||||
db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, "", 4, 0, 0, "")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create new user db: %v", err)
|
||||
}
|
||||
return db, close
|
||||
return db, func() {
|
||||
close()
|
||||
baseclose()
|
||||
}
|
||||
}
|
||||
|
||||
func mustCreateEvent(t *testing.T, content string) *gomatrixserverlib.HeaderedEvent {
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ import (
|
|||
|
||||
const accountDataSchema = `
|
||||
-- Stores data about accounts data.
|
||||
CREATE TABLE IF NOT EXISTS account_data (
|
||||
CREATE TABLE IF NOT EXISTS userapi_account_datas (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL,
|
||||
-- The room ID for this data (empty string if not specific to a room)
|
||||
|
|
@ -41,15 +41,15 @@ CREATE TABLE IF NOT EXISTS account_data (
|
|||
`
|
||||
|
||||
const insertAccountDataSQL = `
|
||||
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
|
||||
`
|
||||
|
||||
const selectAccountDataSQL = "" +
|
||||
"SELECT room_id, type, content FROM account_data WHERE localpart = $1"
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
|
||||
|
||||
const selectAccountDataByTypeSQL = "" +
|
||||
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
|
||||
type accountDataStatements struct {
|
||||
insertAccountDataStmt *sql.Stmt
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ import (
|
|||
|
||||
const accountsSchema = `
|
||||
-- Stores data about accounts.
|
||||
CREATE TABLE IF NOT EXISTS account_accounts (
|
||||
CREATE TABLE IF NOT EXISTS userapi_accounts (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
-- When this account was first created, as a unix timestamp (ms resolution).
|
||||
|
|
@ -51,22 +51,22 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
|||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
||||
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
|
||||
const deactivateAccountSQL = "" +
|
||||
"UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1"
|
||||
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1"
|
||||
|
||||
const selectAccountByLocalpartSQL = "" +
|
||||
"SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
|
||||
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
|
||||
|
||||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM account_accounts WHERE localpart ~ '^[0-9]{1,}$'"
|
||||
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'"
|
||||
|
||||
type accountsStatements struct {
|
||||
insertAccountStmt *sql.Stmt
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import (
|
|||
)
|
||||
|
||||
func UpIsActive(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;")
|
||||
_, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
|
|
@ -15,7 +15,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
|
|||
}
|
||||
|
||||
func DownIsActive(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN is_deactivated;")
|
||||
_, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts DROP COLUMN is_deactivated;")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@ import (
|
|||
|
||||
func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000;
|
||||
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS ip TEXT;
|
||||
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
|
||||
ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000;
|
||||
ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS ip TEXT;
|
||||
ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
|
|
@ -19,9 +19,9 @@ ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
|
|||
|
||||
func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE device_devices DROP COLUMN last_seen_ts;
|
||||
ALTER TABLE device_devices DROP COLUMN ip;
|
||||
ALTER TABLE device_devices DROP COLUMN user_agent;`)
|
||||
ALTER TABLE userapi_devices DROP COLUMN last_seen_ts;
|
||||
ALTER TABLE userapi_devices DROP COLUMN ip;
|
||||
ALTER TABLE userapi_devices DROP COLUMN user_agent;`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ import (
|
|||
func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
|
||||
// initially set every account to useraccount, change appservice and guest accounts afterwards
|
||||
// (user = 1, guest = 2, admin = 3, appservice = 4)
|
||||
_, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1;
|
||||
UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> '';
|
||||
UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$';
|
||||
ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`,
|
||||
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1;
|
||||
UPDATE userapi_accounts SET account_type = 4 WHERE appservice_id <> '';
|
||||
UPDATE userapi_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$';
|
||||
ALTER TABLE userapi_accounts ALTER COLUMN account_type DROP DEFAULT;`,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
|
|
@ -21,7 +21,7 @@ ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`,
|
|||
}
|
||||
|
||||
func DownAddAccountType(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN account_type;")
|
||||
_, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts DROP COLUMN account_type;")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,102 @@
|
|||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
var renameTableMappings = map[string]string{
|
||||
"account_accounts": "userapi_accounts",
|
||||
"account_data": "userapi_account_datas",
|
||||
"device_devices": "userapi_devices",
|
||||
"account_e2e_room_keys": "userapi_key_backups",
|
||||
"account_e2e_room_keys_versions": "userapi_key_backup_versions",
|
||||
"login_tokens": "userapi_login_tokens",
|
||||
"open_id_tokens": "userapi_openid_tokens",
|
||||
"account_profiles": "userapi_profiles",
|
||||
"account_threepid": "userapi_threepids",
|
||||
}
|
||||
|
||||
var renameSequenceMappings = map[string]string{
|
||||
"device_session_id_seq": "userapi_device_session_id_seq",
|
||||
"account_e2e_room_keys_versions_seq": "userapi_key_backup_versions_seq",
|
||||
}
|
||||
|
||||
var renameIndicesMappings = map[string]string{
|
||||
"device_localpart_id_idx": "userapi_device_localpart_id_idx",
|
||||
"e2e_room_keys_idx": "userapi_key_backups_idx",
|
||||
"e2e_room_keys_versions_idx": "userapi_key_backups_versions_idx",
|
||||
"account_e2e_room_keys_versions_idx": "userapi_key_backup_versions_idx",
|
||||
"login_tokens_expiration_idx": "userapi_login_tokens_expiration_idx",
|
||||
"account_threepid_localpart": "userapi_threepid_idx",
|
||||
}
|
||||
|
||||
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||
// argument in that way so it results in a syntax error in the query.
|
||||
|
||||
func UpRenameTables(ctx context.Context, tx *sql.Tx) error {
|
||||
for old, new := range renameTableMappings {
|
||||
q := fmt.Sprintf(
|
||||
"ALTER TABLE IF EXISTS %s RENAME TO %s;",
|
||||
pq.QuoteIdentifier(old), pq.QuoteIdentifier(new),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
|
||||
}
|
||||
}
|
||||
for old, new := range renameSequenceMappings {
|
||||
q := fmt.Sprintf(
|
||||
"ALTER SEQUENCE IF EXISTS %s RENAME TO %s;",
|
||||
pq.QuoteIdentifier(old), pq.QuoteIdentifier(new),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
|
||||
}
|
||||
}
|
||||
for old, new := range renameIndicesMappings {
|
||||
q := fmt.Sprintf(
|
||||
"ALTER INDEX IF EXISTS %s RENAME TO %s;",
|
||||
pq.QuoteIdentifier(old), pq.QuoteIdentifier(new),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownRenameTables(ctx context.Context, tx *sql.Tx) error {
|
||||
for old, new := range renameTableMappings {
|
||||
q := fmt.Sprintf(
|
||||
"ALTER TABLE IF EXISTS %s RENAME TO %s;",
|
||||
pq.QuoteIdentifier(new), pq.QuoteIdentifier(old),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
|
||||
}
|
||||
}
|
||||
for old, new := range renameSequenceMappings {
|
||||
q := fmt.Sprintf(
|
||||
"ALTER SEQUENCE IF EXISTS %s RENAME TO %s;",
|
||||
pq.QuoteIdentifier(new), pq.QuoteIdentifier(old),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
|
||||
}
|
||||
}
|
||||
for old, new := range renameIndicesMappings {
|
||||
q := fmt.Sprintf(
|
||||
"ALTER INDEX IF EXISTS %s RENAME TO %s;",
|
||||
pq.QuoteIdentifier(new), pq.QuoteIdentifier(old),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -31,10 +31,10 @@ import (
|
|||
|
||||
const devicesSchema = `
|
||||
-- This sequence is used for automatic allocation of session_id.
|
||||
CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
|
||||
CREATE SEQUENCE IF NOT EXISTS userapi_device_session_id_seq START 1;
|
||||
|
||||
-- Stores data about devices.
|
||||
CREATE TABLE IF NOT EXISTS device_devices (
|
||||
CREATE TABLE IF NOT EXISTS userapi_devices (
|
||||
-- The access token granted to this device. This has to be the primary key
|
||||
-- so we can distinguish which device is making a given request.
|
||||
access_token TEXT NOT NULL PRIMARY KEY,
|
||||
|
|
@ -42,7 +42,7 @@ CREATE TABLE IF NOT EXISTS device_devices (
|
|||
-- This can be used as a secure substitution of the access token in situations
|
||||
-- where data is associated with access tokens (e.g. transaction storage),
|
||||
-- so we don't have to store users' access tokens everywhere.
|
||||
session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'),
|
||||
session_id BIGINT NOT NULL DEFAULT nextval('userapi_device_session_id_seq'),
|
||||
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
|
||||
-- access_tokens will be clobbered based on the device ID for a user.
|
||||
device_id TEXT NOT NULL,
|
||||
|
|
@ -65,39 +65,39 @@ CREATE TABLE IF NOT EXISTS device_devices (
|
|||
);
|
||||
|
||||
-- Device IDs must be unique for a given user.
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(localpart, device_id);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id);
|
||||
`
|
||||
|
||||
const insertDeviceSQL = "" +
|
||||
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
|
||||
"INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
|
||||
" RETURNING session_id"
|
||||
|
||||
const selectDeviceByTokenSQL = "" +
|
||||
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
|
||||
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
|
||||
|
||||
const selectDeviceByIDSQL = "" +
|
||||
"SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
|
||||
|
||||
const selectDevicesByLocalpartSQL = "" +
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceNameSQL = "" +
|
||||
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
|
||||
const deleteDeviceSQL = "" +
|
||||
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
|
||||
|
||||
const deleteDevicesByLocalpartSQL = "" +
|
||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
|
||||
|
||||
const deleteDevicesSQL = "" +
|
||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
||||
|
||||
const selectDevicesByIDSQL = "" +
|
||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
|
||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceLastSeen = "" +
|
||||
"UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
|
||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
|
||||
|
||||
type devicesStatements struct {
|
||||
insertDeviceStmt *sql.Stmt
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ import (
|
|||
)
|
||||
|
||||
const keyBackupTableSchema = `
|
||||
CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
|
||||
CREATE TABLE IF NOT EXISTS userapi_key_backups (
|
||||
user_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
|
|
@ -37,31 +37,31 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
|
|||
is_verified BOOLEAN NOT NULL,
|
||||
session_data TEXT NOT NULL
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
|
||||
CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backups_idx ON userapi_key_backups(user_id, room_id, session_id, version);
|
||||
CREATE INDEX IF NOT EXISTS userapi_key_backups_versions_idx ON userapi_key_backups(user_id, version);
|
||||
`
|
||||
|
||||
const insertBackupKeySQL = "" +
|
||||
"INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
|
||||
"INSERT INTO userapi_key_backups(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||
|
||||
const updateBackupKeySQL = "" +
|
||||
"UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
|
||||
"UPDATE userapi_key_backups SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
|
||||
"WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8"
|
||||
|
||||
const countKeysSQL = "" +
|
||||
"SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2"
|
||||
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
|
||||
|
||||
const selectKeysSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
|
||||
"WHERE user_id = $1 AND version = $2"
|
||||
|
||||
const selectKeysByRoomIDSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
|
||||
"WHERE user_id = $1 AND version = $2 AND room_id = $3"
|
||||
|
||||
const selectKeysByRoomIDAndSessionIDSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
|
||||
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
|
||||
|
||||
type keyBackupStatements struct {
|
||||
|
|
|
|||
|
|
@ -26,40 +26,40 @@ import (
|
|||
)
|
||||
|
||||
const keyBackupVersionTableSchema = `
|
||||
CREATE SEQUENCE IF NOT EXISTS account_e2e_room_keys_versions_seq;
|
||||
CREATE SEQUENCE IF NOT EXISTS userapi_key_backup_versions_seq;
|
||||
|
||||
-- the metadata for each generation of encrypted e2e session backups
|
||||
CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
|
||||
CREATE TABLE IF NOT EXISTS userapi_key_backup_versions (
|
||||
user_id TEXT NOT NULL,
|
||||
-- this means no 2 users will ever have the same version of e2e session backups which strictly
|
||||
-- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1.
|
||||
version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'),
|
||||
version BIGINT DEFAULT nextval('userapi_key_backup_versions_seq'),
|
||||
algorithm TEXT NOT NULL,
|
||||
auth_data TEXT NOT NULL,
|
||||
etag TEXT NOT NULL,
|
||||
deleted SMALLINT DEFAULT 0 NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backup_versions_idx ON userapi_key_backup_versions(user_id, version);
|
||||
`
|
||||
|
||||
const insertKeyBackupSQL = "" +
|
||||
"INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
|
||||
"INSERT INTO userapi_key_backup_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
|
||||
|
||||
const updateKeyBackupAuthDataSQL = "" +
|
||||
"UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
|
||||
"UPDATE userapi_key_backup_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
|
||||
|
||||
const updateKeyBackupETagSQL = "" +
|
||||
"UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
|
||||
"UPDATE userapi_key_backup_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
|
||||
|
||||
const deleteKeyBackupSQL = "" +
|
||||
"UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
|
||||
"UPDATE userapi_key_backup_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
|
||||
|
||||
const selectKeyBackupSQL = "" +
|
||||
"SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2"
|
||||
"SELECT algorithm, auth_data, etag, deleted FROM userapi_key_backup_versions WHERE user_id = $1 AND version = $2"
|
||||
|
||||
const selectLatestVersionSQL = "" +
|
||||
"SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1"
|
||||
"SELECT MAX(version) FROM userapi_key_backup_versions WHERE user_id = $1"
|
||||
|
||||
type keyBackupVersionStatements struct {
|
||||
insertKeyBackupStmt *sql.Stmt
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ import (
|
|||
)
|
||||
|
||||
const loginTokenSchema = `
|
||||
CREATE TABLE IF NOT EXISTS login_tokens (
|
||||
CREATE TABLE IF NOT EXISTS userapi_login_tokens (
|
||||
-- The random value of the token issued to a user
|
||||
token TEXT NOT NULL PRIMARY KEY,
|
||||
-- When the token expires
|
||||
|
|
@ -37,17 +37,17 @@ CREATE TABLE IF NOT EXISTS login_tokens (
|
|||
);
|
||||
|
||||
-- This index allows efficient garbage collection of expired tokens.
|
||||
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
|
||||
CREATE INDEX IF NOT EXISTS userapi_login_tokens_expiration_idx ON userapi_login_tokens(token_expires_at);
|
||||
`
|
||||
|
||||
const insertLoginTokenSQL = "" +
|
||||
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteLoginTokenSQL = "" +
|
||||
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
|
||||
"DELETE FROM userapi_login_tokens WHERE token = $1 OR token_expires_at <= $2"
|
||||
|
||||
const selectLoginTokenSQL = "" +
|
||||
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
|
||||
"SELECT user_id FROM userapi_login_tokens WHERE token = $1 AND token_expires_at > $2"
|
||||
|
||||
type loginTokenStatements struct {
|
||||
insertStmt *sql.Stmt
|
||||
|
|
@ -78,7 +78,7 @@ func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx
|
|||
// deleteByToken removes the named token.
|
||||
//
|
||||
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
|
||||
// The login_tokens_expiration_idx index should make that efficient.
|
||||
// The userapi_login_tokens_expiration_idx index should make that efficient.
|
||||
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
|
||||
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import (
|
|||
|
||||
const openIDTokenSchema = `
|
||||
-- Stores data about openid tokens issued for accounts.
|
||||
CREATE TABLE IF NOT EXISTS open_id_tokens (
|
||||
CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
|
||||
-- The value of the token issued to a user
|
||||
token TEXT NOT NULL PRIMARY KEY,
|
||||
-- The Matrix user ID for this account
|
||||
|
|
@ -24,10 +24,10 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
|
|||
`
|
||||
|
||||
const insertOpenIDTokenSQL = "" +
|
||||
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
|
||||
const selectOpenIDTokenSQL = "" +
|
||||
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
|
||||
type openIDTokenStatements struct {
|
||||
insertTokenStmt *sql.Stmt
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ import (
|
|||
|
||||
const profilesSchema = `
|
||||
-- Stores data about accounts profiles.
|
||||
CREATE TABLE IF NOT EXISTS account_profiles (
|
||||
CREATE TABLE IF NOT EXISTS userapi_profiles (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
-- The display name for this account
|
||||
|
|
@ -38,19 +38,19 @@ CREATE TABLE IF NOT EXISTS account_profiles (
|
|||
`
|
||||
|
||||
const insertProfileSQL = "" +
|
||||
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
||||
|
||||
const selectProfileByLocalpartSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
|
||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
|
||||
|
||||
const setAvatarURLSQL = "" +
|
||||
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
|
||||
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2"
|
||||
|
||||
const setDisplayNameSQL = "" +
|
||||
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
||||
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2"
|
||||
|
||||
const selectProfilesBySearchSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
|
||||
type profilesStatements struct {
|
||||
serverNoticesLocalpart string
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON usera
|
|||
|
||||
const countUsersLastSeenAfterSQL = "" +
|
||||
"SELECT COUNT(*) FROM (" +
|
||||
" SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " +
|
||||
" SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " +
|
||||
" GROUP BY localpart" +
|
||||
" ) u"
|
||||
|
||||
|
|
@ -62,7 +62,7 @@ R30Users counts the number of 30 day retained users, defined as:
|
|||
const countR30UsersSQL = `
|
||||
SELECT platform, COUNT(*) FROM (
|
||||
SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts)
|
||||
FROM account_accounts users
|
||||
FROM userapi_accounts users
|
||||
INNER JOIN
|
||||
(SELECT
|
||||
localpart, last_seen_ts,
|
||||
|
|
@ -75,7 +75,7 @@ SELECT platform, COUNT(*) FROM (
|
|||
ELSE 'unknown'
|
||||
END
|
||||
AS platform
|
||||
FROM device_devices
|
||||
FROM userapi_devices
|
||||
) uip
|
||||
ON users.localpart = uip.localpart
|
||||
AND users.account_type <> 4
|
||||
|
|
@ -121,7 +121,7 @@ GROUP BY client_type
|
|||
`
|
||||
|
||||
const countUserByAccountTypeSQL = `
|
||||
SELECT COUNT(*) FROM account_accounts WHERE account_type = ANY($1)
|
||||
SELECT COUNT(*) FROM userapi_accounts WHERE account_type = ANY($1)
|
||||
`
|
||||
|
||||
// $1 = All non guest AccountType IDs
|
||||
|
|
@ -134,7 +134,7 @@ SELECT user_type, COUNT(*) AS count FROM (
|
|||
WHEN account_type = $2 AND appservice_id IS NULL THEN 'guest'
|
||||
WHEN account_type = ANY($1) AND appservice_id IS NOT NULL THEN 'bridged'
|
||||
END AS user_type
|
||||
FROM account_accounts
|
||||
FROM userapi_accounts
|
||||
WHERE created_ts > $3
|
||||
) AS t GROUP BY user_type
|
||||
`
|
||||
|
|
@ -143,14 +143,14 @@ SELECT user_type, COUNT(*) AS count FROM (
|
|||
const updateUserDailyVisitsSQL = `
|
||||
INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent)
|
||||
SELECT u.localpart, u.device_id, $1, MAX(u.user_agent)
|
||||
FROM device_devices AS u
|
||||
FROM userapi_devices AS u
|
||||
LEFT JOIN (
|
||||
SELECT localpart, device_id, timestamp FROM userapi_daily_visits
|
||||
WHERE timestamp = $1
|
||||
) udv
|
||||
ON u.localpart = udv.localpart AND u.device_id = udv.device_id
|
||||
INNER JOIN device_devices d ON d.localpart = u.localpart
|
||||
INNER JOIN account_accounts a ON a.localpart = u.localpart
|
||||
INNER JOIN userapi_devices d ON d.localpart = u.localpart
|
||||
INNER JOIN userapi_accounts a ON a.localpart = u.localpart
|
||||
WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3
|
||||
AND a.account_type in (1, 3)
|
||||
GROUP BY u.localpart, u.device_id
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/shared"
|
||||
|
||||
// Import the postgres database driver.
|
||||
|
|
@ -36,6 +37,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
m := sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: rename tables",
|
||||
Up: deltas.UpRenameTables,
|
||||
Down: deltas.DownRenameTables,
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountDataTable, err := NewPostgresAccountDataTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ import (
|
|||
|
||||
const threepidSchema = `
|
||||
-- Stores data about third party identifiers
|
||||
CREATE TABLE IF NOT EXISTS account_threepid (
|
||||
CREATE TABLE IF NOT EXISTS userapi_threepids (
|
||||
-- The third party identifier
|
||||
threepid TEXT NOT NULL,
|
||||
-- The 3PID medium
|
||||
|
|
@ -37,20 +37,20 @@ CREATE TABLE IF NOT EXISTS account_threepid (
|
|||
PRIMARY KEY(threepid, medium)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart);
|
||||
CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart);
|
||||
`
|
||||
|
||||
const selectLocalpartForThreePIDSQL = "" +
|
||||
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
|
||||
const selectThreePIDsForLocalpartSQL = "" +
|
||||
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
|
||||
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
|
||||
|
||||
const insertThreePIDSQL = "" +
|
||||
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteThreePIDSQL = "" +
|
||||
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
|
||||
type threepidStatements struct {
|
||||
selectLocalpartForThreePIDStmt *sql.Stmt
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ import (
|
|||
|
||||
const accountDataSchema = `
|
||||
-- Stores data about accounts data.
|
||||
CREATE TABLE IF NOT EXISTS account_data (
|
||||
CREATE TABLE IF NOT EXISTS userapi_account_datas (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL,
|
||||
-- The room ID for this data (empty string if not specific to a room)
|
||||
|
|
@ -40,15 +40,15 @@ CREATE TABLE IF NOT EXISTS account_data (
|
|||
`
|
||||
|
||||
const insertAccountDataSQL = `
|
||||
INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
|
||||
`
|
||||
|
||||
const selectAccountDataSQL = "" +
|
||||
"SELECT room_id, type, content FROM account_data WHERE localpart = $1"
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
|
||||
|
||||
const selectAccountDataByTypeSQL = "" +
|
||||
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
|
||||
type accountDataStatements struct {
|
||||
db *sql.DB
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ import (
|
|||
|
||||
const accountsSchema = `
|
||||
-- Stores data about accounts.
|
||||
CREATE TABLE IF NOT EXISTS account_accounts (
|
||||
CREATE TABLE IF NOT EXISTS userapi_accounts (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
-- When this account was first created, as a unix timestamp (ms resolution).
|
||||
|
|
@ -51,22 +51,22 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
|||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
||||
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
|
||||
const deactivateAccountSQL = "" +
|
||||
"UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
||||
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
||||
|
||||
const selectAccountByLocalpartSQL = "" +
|
||||
"SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
|
||||
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
|
||||
|
||||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0"
|
||||
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0"
|
||||
|
||||
type accountsStatements struct {
|
||||
db *sql.DB
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ import (
|
|||
|
||||
func UpIsActive(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
|
||||
CREATE TABLE account_accounts (
|
||||
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
||||
CREATE TABLE userapi_accounts (
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
created_ts BIGINT NOT NULL,
|
||||
password_hash TEXT,
|
||||
|
|
@ -17,13 +17,13 @@ CREATE TABLE account_accounts (
|
|||
is_deactivated BOOLEAN DEFAULT 0
|
||||
);
|
||||
INSERT
|
||||
INTO account_accounts (
|
||||
INTO userapi_accounts (
|
||||
localpart, created_ts, password_hash, appservice_id
|
||||
) SELECT
|
||||
localpart, created_ts, password_hash, appservice_id
|
||||
FROM account_accounts_tmp
|
||||
FROM userapi_accounts_tmp
|
||||
;
|
||||
DROP TABLE account_accounts_tmp;`)
|
||||
DROP TABLE userapi_accounts_tmp;`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
|
|
@ -32,21 +32,21 @@ DROP TABLE account_accounts_tmp;`)
|
|||
|
||||
func DownIsActive(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
|
||||
CREATE TABLE account_accounts (
|
||||
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
||||
CREATE TABLE userapi_accounts (
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
created_ts BIGINT NOT NULL,
|
||||
password_hash TEXT,
|
||||
appservice_id TEXT
|
||||
);
|
||||
INSERT
|
||||
INTO account_accounts (
|
||||
INTO userapi_accounts (
|
||||
localpart, created_ts, password_hash, appservice_id
|
||||
) SELECT
|
||||
localpart, created_ts, password_hash, appservice_id
|
||||
FROM account_accounts_tmp
|
||||
FROM userapi_accounts_tmp
|
||||
;
|
||||
DROP TABLE account_accounts_tmp;`)
|
||||
DROP TABLE userapi_accounts_tmp;`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ import (
|
|||
|
||||
func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE device_devices RENAME TO device_devices_tmp;
|
||||
CREATE TABLE device_devices (
|
||||
ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp;
|
||||
CREATE TABLE userapi_devices (
|
||||
access_token TEXT PRIMARY KEY,
|
||||
session_id INTEGER,
|
||||
device_id TEXT ,
|
||||
|
|
@ -22,12 +22,12 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
|
|||
UNIQUE (localpart, device_id)
|
||||
);
|
||||
INSERT
|
||||
INTO device_devices (
|
||||
INTO userapi_devices (
|
||||
access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent
|
||||
) SELECT
|
||||
access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', ''
|
||||
FROM device_devices_tmp;
|
||||
DROP TABLE device_devices_tmp;`)
|
||||
FROM userapi_devices_tmp;
|
||||
DROP TABLE userapi_devices_tmp;`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
|
|
@ -36,8 +36,8 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
|
|||
|
||||
func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE device_devices RENAME TO device_devices_tmp;
|
||||
CREATE TABLE IF NOT EXISTS device_devices (
|
||||
ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp;
|
||||
CREATE TABLE IF NOT EXISTS userapi_devices (
|
||||
access_token TEXT PRIMARY KEY,
|
||||
session_id INTEGER,
|
||||
device_id TEXT ,
|
||||
|
|
@ -47,12 +47,12 @@ CREATE TABLE IF NOT EXISTS device_devices (
|
|||
UNIQUE (localpart, device_id)
|
||||
);
|
||||
INSERT
|
||||
INTO device_devices (
|
||||
INTO userapi_devices (
|
||||
access_token, session_id, device_id, localpart, created_ts, display_name
|
||||
) SELECT
|
||||
access_token, session_id, device_id, localpart, created_ts, display_name
|
||||
FROM device_devices_tmp;
|
||||
DROP TABLE device_devices_tmp;`)
|
||||
FROM userapi_devices_tmp;
|
||||
DROP TABLE userapi_devices_tmp;`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ import (
|
|||
func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
|
||||
// initially set every account to useraccount, change appservice and guest accounts afterwards
|
||||
// (user = 1, guest = 2, admin = 3, appservice = 4)
|
||||
_, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
|
||||
CREATE TABLE account_accounts (
|
||||
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
||||
CREATE TABLE userapi_accounts (
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
created_ts BIGINT NOT NULL,
|
||||
password_hash TEXT,
|
||||
|
|
@ -19,15 +19,15 @@ CREATE TABLE account_accounts (
|
|||
account_type INTEGER NOT NULL
|
||||
);
|
||||
INSERT
|
||||
INTO account_accounts (
|
||||
INTO userapi_accounts (
|
||||
localpart, created_ts, password_hash, appservice_id, account_type
|
||||
) SELECT
|
||||
localpart, created_ts, password_hash, appservice_id, 1
|
||||
FROM account_accounts_tmp
|
||||
FROM userapi_accounts_tmp
|
||||
;
|
||||
UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> '';
|
||||
UPDATE account_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*';
|
||||
DROP TABLE account_accounts_tmp;`)
|
||||
UPDATE userapi_accounts SET account_type = 4 WHERE appservice_id <> '';
|
||||
UPDATE userapi_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*';
|
||||
DROP TABLE userapi_accounts_tmp;`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add column: %w", err)
|
||||
}
|
||||
|
|
@ -35,7 +35,7 @@ DROP TABLE account_accounts_tmp;`)
|
|||
}
|
||||
|
||||
func DownAddAccountType(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts DROP COLUMN account_type;`)
|
||||
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts DROP COLUMN account_type;`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
|
|
|
|||
109
userapi/storage/sqlite3/deltas/2022101711000000_rename_tables.go
Normal file
109
userapi/storage/sqlite3/deltas/2022101711000000_rename_tables.go
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var renameTableMappings = map[string]string{
|
||||
"account_accounts": "userapi_accounts",
|
||||
"account_data": "userapi_account_datas",
|
||||
"device_devices": "userapi_devices",
|
||||
"account_e2e_room_keys": "userapi_key_backups",
|
||||
"account_e2e_room_keys_versions": "userapi_key_backup_versions",
|
||||
"login_tokens": "userapi_login_tokens",
|
||||
"open_id_tokens": "userapi_openid_tokens",
|
||||
"account_profiles": "userapi_profiles",
|
||||
"account_threepid": "userapi_threepids",
|
||||
}
|
||||
|
||||
var renameIndicesMappings = map[string]string{
|
||||
"device_localpart_id_idx": "userapi_device_localpart_id_idx",
|
||||
"e2e_room_keys_idx": "userapi_key_backups_idx",
|
||||
"e2e_room_keys_versions_idx": "userapi_key_backups_versions_idx",
|
||||
"account_e2e_room_keys_versions_idx": "userapi_key_backup_versions_idx",
|
||||
"login_tokens_expiration_idx": "userapi_login_tokens_expiration_idx",
|
||||
"account_threepid_localpart": "userapi_threepid_idx",
|
||||
}
|
||||
|
||||
func UpRenameTables(ctx context.Context, tx *sql.Tx) error {
|
||||
for old, new := range renameTableMappings {
|
||||
// SQLite has no "IF EXISTS" so check if the table exists.
|
||||
var name string
|
||||
if err := tx.QueryRowContext(
|
||||
ctx, "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = $1;", old,
|
||||
).Scan(&name); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
q := fmt.Sprintf(
|
||||
"ALTER TABLE %s RENAME TO %s;", old, new,
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("rename table %q to %q error: %w", old, new, err)
|
||||
}
|
||||
}
|
||||
for old, new := range renameIndicesMappings {
|
||||
var query string
|
||||
if err := tx.QueryRowContext(
|
||||
ctx, "SELECT sql FROM sqlite_schema WHERE type = 'index' AND name = $1;", old,
|
||||
).Scan(&query); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
query = strings.Replace(query, old, new, 1)
|
||||
if _, err := tx.ExecContext(ctx, fmt.Sprintf("DROP INDEX %s;", old)); err != nil {
|
||||
return fmt.Errorf("drop index %q to %q error: %w", old, new, err)
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, query); err != nil {
|
||||
return fmt.Errorf("recreate index %q to %q error: %w", old, new, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownRenameTables(ctx context.Context, tx *sql.Tx) error {
|
||||
for old, new := range renameTableMappings {
|
||||
// SQLite has no "IF EXISTS" so check if the table exists.
|
||||
var name string
|
||||
if err := tx.QueryRowContext(
|
||||
ctx, "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = $1;", new,
|
||||
).Scan(&name); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
q := fmt.Sprintf(
|
||||
"ALTER TABLE %s RENAME TO %s;", new, old,
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("rename table %q to %q error: %w", new, old, err)
|
||||
}
|
||||
}
|
||||
for old, new := range renameIndicesMappings {
|
||||
var query string
|
||||
if err := tx.QueryRowContext(
|
||||
ctx, "SELECT sql FROM sqlite_schema WHERE type = 'index' AND name = $1;", new,
|
||||
).Scan(&query); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
query = strings.Replace(query, new, old, 1)
|
||||
if _, err := tx.ExecContext(ctx, fmt.Sprintf("DROP INDEX %s;", new)); err != nil {
|
||||
return fmt.Errorf("drop index %q to %q error: %w", new, old, err)
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, query); err != nil {
|
||||
return fmt.Errorf("recreate index %q to %q error: %w", new, old, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -35,7 +35,7 @@ const devicesSchema = `
|
|||
-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
|
||||
|
||||
-- Stores data about devices.
|
||||
CREATE TABLE IF NOT EXISTS device_devices (
|
||||
CREATE TABLE IF NOT EXISTS userapi_devices (
|
||||
access_token TEXT PRIMARY KEY,
|
||||
session_id INTEGER,
|
||||
device_id TEXT ,
|
||||
|
|
@ -51,38 +51,38 @@ CREATE TABLE IF NOT EXISTS device_devices (
|
|||
`
|
||||
|
||||
const insertDeviceSQL = "" +
|
||||
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
||||
"INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||
|
||||
const selectDevicesCountSQL = "" +
|
||||
"SELECT COUNT(access_token) FROM device_devices"
|
||||
"SELECT COUNT(access_token) FROM userapi_devices"
|
||||
|
||||
const selectDeviceByTokenSQL = "" +
|
||||
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
|
||||
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
|
||||
|
||||
const selectDeviceByIDSQL = "" +
|
||||
"SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
|
||||
|
||||
const selectDevicesByLocalpartSQL = "" +
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceNameSQL = "" +
|
||||
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
|
||||
const deleteDeviceSQL = "" +
|
||||
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
|
||||
|
||||
const deleteDevicesByLocalpartSQL = "" +
|
||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
|
||||
|
||||
const deleteDevicesSQL = "" +
|
||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||
|
||||
const selectDevicesByIDSQL = "" +
|
||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceLastSeen = "" +
|
||||
"UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
|
||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
|
||||
|
||||
type devicesStatements struct {
|
||||
db *sql.DB
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ import (
|
|||
)
|
||||
|
||||
const keyBackupTableSchema = `
|
||||
CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
|
||||
CREATE TABLE IF NOT EXISTS userapi_key_backups (
|
||||
user_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
|
|
@ -37,31 +37,31 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
|
|||
is_verified BOOLEAN NOT NULL,
|
||||
session_data TEXT NOT NULL
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
|
||||
CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON userapi_key_backups(user_id, room_id, session_id, version);
|
||||
CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON userapi_key_backups(user_id, version);
|
||||
`
|
||||
|
||||
const insertBackupKeySQL = "" +
|
||||
"INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
|
||||
"INSERT INTO userapi_key_backups(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||
|
||||
const updateBackupKeySQL = "" +
|
||||
"UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
|
||||
"UPDATE userapi_key_backups SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
|
||||
"WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8"
|
||||
|
||||
const countKeysSQL = "" +
|
||||
"SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2"
|
||||
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
|
||||
|
||||
const selectKeysSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
|
||||
"WHERE user_id = $1 AND version = $2"
|
||||
|
||||
const selectKeysByRoomIDSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
|
||||
"WHERE user_id = $1 AND version = $2 AND room_id = $3"
|
||||
|
||||
const selectKeysByRoomIDAndSessionIDSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
|
||||
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
|
||||
|
||||
type keyBackupStatements struct {
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ import (
|
|||
|
||||
const keyBackupVersionTableSchema = `
|
||||
-- the metadata for each generation of encrypted e2e session backups
|
||||
CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
|
||||
CREATE TABLE IF NOT EXISTS userapi_key_backup_versions (
|
||||
user_id TEXT NOT NULL,
|
||||
-- this means no 2 users will ever have the same version of e2e session backups which strictly
|
||||
-- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1.
|
||||
|
|
@ -38,26 +38,26 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
|
|||
deleted INTEGER DEFAULT 0 NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backup_versions_idx ON userapi_key_backup_versions(user_id, version);
|
||||
`
|
||||
|
||||
const insertKeyBackupSQL = "" +
|
||||
"INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
|
||||
"INSERT INTO userapi_key_backup_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
|
||||
|
||||
const updateKeyBackupAuthDataSQL = "" +
|
||||
"UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
|
||||
"UPDATE userapi_key_backup_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
|
||||
|
||||
const updateKeyBackupETagSQL = "" +
|
||||
"UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
|
||||
"UPDATE userapi_key_backup_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
|
||||
|
||||
const deleteKeyBackupSQL = "" +
|
||||
"UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
|
||||
"UPDATE userapi_key_backup_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
|
||||
|
||||
const selectKeyBackupSQL = "" +
|
||||
"SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2"
|
||||
"SELECT algorithm, auth_data, etag, deleted FROM userapi_key_backup_versions WHERE user_id = $1 AND version = $2"
|
||||
|
||||
const selectLatestVersionSQL = "" +
|
||||
"SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1"
|
||||
"SELECT MAX(version) FROM userapi_key_backup_versions WHERE user_id = $1"
|
||||
|
||||
type keyBackupVersionStatements struct {
|
||||
insertKeyBackupStmt *sql.Stmt
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ type loginTokenStatements struct {
|
|||
}
|
||||
|
||||
const loginTokenSchema = `
|
||||
CREATE TABLE IF NOT EXISTS login_tokens (
|
||||
CREATE TABLE IF NOT EXISTS userapi_login_tokens (
|
||||
-- The random value of the token issued to a user
|
||||
token TEXT NOT NULL PRIMARY KEY,
|
||||
-- When the token expires
|
||||
|
|
@ -43,17 +43,17 @@ CREATE TABLE IF NOT EXISTS login_tokens (
|
|||
);
|
||||
|
||||
-- This index allows efficient garbage collection of expired tokens.
|
||||
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
|
||||
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON userapi_login_tokens(token_expires_at);
|
||||
`
|
||||
|
||||
const insertLoginTokenSQL = "" +
|
||||
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteLoginTokenSQL = "" +
|
||||
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
|
||||
"DELETE FROM userapi_login_tokens WHERE token = $1 OR token_expires_at <= $2"
|
||||
|
||||
const selectLoginTokenSQL = "" +
|
||||
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
|
||||
"SELECT user_id FROM userapi_login_tokens WHERE token = $1 AND token_expires_at > $2"
|
||||
|
||||
func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
|
||||
s := &loginTokenStatements{}
|
||||
|
|
@ -78,7 +78,7 @@ func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx
|
|||
// deleteByToken removes the named token.
|
||||
//
|
||||
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
|
||||
// The login_tokens_expiration_idx index should make that efficient.
|
||||
// The userapi_login_tokens_expiration_idx index should make that efficient.
|
||||
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
|
||||
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import (
|
|||
|
||||
const openIDTokenSchema = `
|
||||
-- Stores data about accounts.
|
||||
CREATE TABLE IF NOT EXISTS open_id_tokens (
|
||||
CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
|
||||
-- The value of the token issued to a user
|
||||
token TEXT NOT NULL PRIMARY KEY,
|
||||
-- The Matrix user ID for this account
|
||||
|
|
@ -24,10 +24,10 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
|
|||
`
|
||||
|
||||
const insertOpenIDTokenSQL = "" +
|
||||
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
|
||||
const selectOpenIDTokenSQL = "" +
|
||||
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
|
||||
type openIDTokenStatements struct {
|
||||
db *sql.DB
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ import (
|
|||
|
||||
const profilesSchema = `
|
||||
-- Stores data about accounts profiles.
|
||||
CREATE TABLE IF NOT EXISTS account_profiles (
|
||||
CREATE TABLE IF NOT EXISTS userapi_profiles (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
-- The display name for this account
|
||||
|
|
@ -38,19 +38,19 @@ CREATE TABLE IF NOT EXISTS account_profiles (
|
|||
`
|
||||
|
||||
const insertProfileSQL = "" +
|
||||
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
||||
|
||||
const selectProfileByLocalpartSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
|
||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
|
||||
|
||||
const setAvatarURLSQL = "" +
|
||||
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
|
||||
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2"
|
||||
|
||||
const setDisplayNameSQL = "" +
|
||||
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
||||
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2"
|
||||
|
||||
const selectProfilesBySearchSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
|
||||
type profilesStatements struct {
|
||||
db *sql.DB
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON usera
|
|||
|
||||
const countUsersLastSeenAfterSQL = "" +
|
||||
"SELECT COUNT(*) FROM (" +
|
||||
" SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " +
|
||||
" SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " +
|
||||
" GROUP BY localpart" +
|
||||
" ) u"
|
||||
|
||||
|
|
@ -63,7 +63,7 @@ R30Users counts the number of 30 day retained users, defined as:
|
|||
const countR30UsersSQL = `
|
||||
SELECT platform, COUNT(*) FROM (
|
||||
SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts)
|
||||
FROM account_accounts users
|
||||
FROM userapi_accounts users
|
||||
INNER JOIN
|
||||
(SELECT
|
||||
localpart, last_seen_ts,
|
||||
|
|
@ -76,7 +76,7 @@ SELECT platform, COUNT(*) FROM (
|
|||
ELSE 'unknown'
|
||||
END
|
||||
AS platform
|
||||
FROM device_devices
|
||||
FROM userapi_devices
|
||||
) uip
|
||||
ON users.localpart = uip.localpart
|
||||
AND users.account_type <> 4
|
||||
|
|
@ -126,7 +126,7 @@ GROUP BY client_type
|
|||
`
|
||||
|
||||
const countUserByAccountTypeSQL = `
|
||||
SELECT COUNT(*) FROM account_accounts WHERE account_type IN ($1)
|
||||
SELECT COUNT(*) FROM userapi_accounts WHERE account_type IN ($1)
|
||||
`
|
||||
|
||||
// $1 = Guest AccountType
|
||||
|
|
@ -139,7 +139,7 @@ SELECT user_type, COUNT(*) AS count FROM (
|
|||
WHEN account_type = $4 AND appservice_id IS NULL THEN 'guest'
|
||||
WHEN account_type IN ($5) AND appservice_id IS NOT NULL THEN 'bridged'
|
||||
END AS user_type
|
||||
FROM account_accounts
|
||||
FROM userapi_accounts
|
||||
WHERE created_ts > $8
|
||||
) AS t GROUP BY user_type
|
||||
`
|
||||
|
|
@ -148,14 +148,14 @@ SELECT user_type, COUNT(*) AS count FROM (
|
|||
const updateUserDailyVisitsSQL = `
|
||||
INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent)
|
||||
SELECT u.localpart, u.device_id, $1, MAX(u.user_agent)
|
||||
FROM device_devices AS u
|
||||
FROM userapi_devices AS u
|
||||
LEFT JOIN (
|
||||
SELECT localpart, device_id, timestamp FROM userapi_daily_visits
|
||||
WHERE timestamp = $1
|
||||
) udv
|
||||
ON u.localpart = udv.localpart AND u.device_id = udv.device_id
|
||||
INNER JOIN device_devices d ON d.localpart = u.localpart
|
||||
INNER JOIN account_accounts a ON a.localpart = u.localpart
|
||||
INNER JOIN userapi_devices d ON d.localpart = u.localpart
|
||||
INNER JOIN userapi_accounts a ON a.localpart = u.localpart
|
||||
WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3
|
||||
AND a.account_type in (1, 3)
|
||||
GROUP BY u.localpart, u.device_id
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/config"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
||||
)
|
||||
|
||||
// NewDatabase creates a new accounts and profiles database
|
||||
|
|
@ -34,6 +35,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
m := sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: rename tables",
|
||||
Up: deltas.UpRenameTables,
|
||||
Down: deltas.DownRenameTables,
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountDataTable, err := NewSQLiteAccountDataTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ import (
|
|||
|
||||
const threepidSchema = `
|
||||
-- Stores data about third party identifiers
|
||||
CREATE TABLE IF NOT EXISTS account_threepid (
|
||||
CREATE TABLE IF NOT EXISTS userapi_threepids (
|
||||
-- The third party identifier
|
||||
threepid TEXT NOT NULL,
|
||||
-- The 3PID medium
|
||||
|
|
@ -38,20 +38,20 @@ CREATE TABLE IF NOT EXISTS account_threepid (
|
|||
PRIMARY KEY(threepid, medium)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart);
|
||||
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart);
|
||||
`
|
||||
|
||||
const selectLocalpartForThreePIDSQL = "" +
|
||||
"SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
|
||||
const selectThreePIDsForLocalpartSQL = "" +
|
||||
"SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
|
||||
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
|
||||
|
||||
const insertThreePIDSQL = "" +
|
||||
"INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteThreePIDSQL = "" +
|
||||
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
|
||||
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
|
||||
type threepidStatements struct {
|
||||
db *sql.DB
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
|
|
@ -29,14 +30,18 @@ var (
|
|||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{
|
||||
db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
|
||||
if err != nil {
|
||||
t.Fatalf("NewUserAPIDatabase returned %s", err)
|
||||
}
|
||||
return db, close
|
||||
return db, func() {
|
||||
close()
|
||||
baseclose()
|
||||
}
|
||||
}
|
||||
|
||||
// Tests storing and getting account data
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ func mustUpdateDeviceLastSeen(
|
|||
timestamp time.Time,
|
||||
) {
|
||||
t.Helper()
|
||||
_, err := db.ExecContext(ctx, "UPDATE device_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart)
|
||||
_, err := db.ExecContext(ctx, "UPDATE userapi_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to update device last seen")
|
||||
}
|
||||
|
|
@ -119,7 +119,7 @@ func mustUserUpdateRegistered(
|
|||
localpart string,
|
||||
timestamp time.Time,
|
||||
) {
|
||||
_, err := db.ExecContext(ctx, "UPDATE account_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart)
|
||||
_, err := db.ExecContext(ctx, "UPDATE userapi_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to update device last seen")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/dendrite/userapi"
|
||||
"github.com/matrix-org/dendrite/userapi/inthttp"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -48,9 +49,9 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap
|
|||
if opts.loginTokenLifetime == 0 {
|
||||
opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
|
||||
}
|
||||
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
|
||||
accountDB, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{
|
||||
accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
|
||||
if err != nil {
|
||||
|
|
@ -64,9 +65,12 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap
|
|||
}
|
||||
|
||||
return &internal.UserInternalAPI{
|
||||
DB: accountDB,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
}, accountDB, close
|
||||
DB: accountDB,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
}, accountDB, func() {
|
||||
close()
|
||||
baseclose()
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryProfile(t *testing.T) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue