mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-04 20:53:09 -06:00
Merge remote-tracking branch 'upstream/main'
This commit is contained in:
commit
a9808ae7e4
1
.github/workflows/dendrite.yml
vendored
1
.github/workflows/dendrite.yml
vendored
|
|
@ -250,6 +250,7 @@ jobs:
|
||||||
env:
|
env:
|
||||||
POSTGRES: ${{ matrix.postgres && 1}}
|
POSTGRES: ${{ matrix.postgres && 1}}
|
||||||
API: ${{ matrix.api && 1 }}
|
API: ${{ matrix.api && 1 }}
|
||||||
|
SYTEST_BRANCH: ${{ github.head_ref }}
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- name: Run Sytest
|
- name: Run Sytest
|
||||||
|
|
|
||||||
29
CHANGES.md
29
CHANGES.md
|
|
@ -1,5 +1,34 @@
|
||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
## Dendrite 0.8.2 (2022-04-27)
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* Lazy-loading has been added to the `/sync` endpoint, which should speed up syncs considerably
|
||||||
|
* Filtering has been added to the `/messages` endpoint
|
||||||
|
* The room summary now contains "heroes" (up to 5 users in the room) for clients to display when no room name is set
|
||||||
|
* The existing lazy-loading caches will now be used by `/messages` and `/context` so that member events will not be sent to clients more times than necessary
|
||||||
|
* The account data stream now uses the provided filters
|
||||||
|
* The built-in NATS Server has been updated to version 2.8.0
|
||||||
|
* The `/state` and `/state_ids` endpoints will now return `M_NOT_FOUND` for rejected events
|
||||||
|
* Repeated calls to the `/redact` endpoint will now be idempotent when a transaction ID is given
|
||||||
|
* Dendrite should now be able to run as a Windows service under Service Control Manager
|
||||||
|
|
||||||
|
### Fixes
|
||||||
|
|
||||||
|
* Fictitious presence updates will no longer be created for users which have not sent us presence updates, which should speed up complete syncs considerably
|
||||||
|
* Uploading cross-signing device signatures should now be more reliable, fixing a number of bugs with cross-signing
|
||||||
|
* All account data should now be sent properly on a complete sync, which should eliminate problems with client settings or key backups appearing to be missing
|
||||||
|
* Account data will now be limited correctly on incremental syncs, returning the stream position of the most recent update rather than the latest stream position
|
||||||
|
* Account data will not be sent for parted rooms, which should reduce the number of left/forgotten rooms reappearing in clients as empty rooms
|
||||||
|
* The TURN username hash has been fixed which should help to resolve some problems when using TURN for voice calls (contributed by [fcwoknhenuxdfiyv](https://github.com/fcwoknhenuxdfiyv))
|
||||||
|
* Push rules can no longer be modified using the account data endpoints
|
||||||
|
* Querying account availability should now work properly in polylith deployments
|
||||||
|
* A number of bugs with sync filters have been fixed
|
||||||
|
* A default sync filter will now be used if the request contains a filter ID that does not exist
|
||||||
|
* The `pushkey_ts` field is now using seconds instead of milliseconds
|
||||||
|
* A race condition when gracefully shutting down has been fixed, so JetStream should no longer cause the process to exit before other Dendrite components are finished shutting down
|
||||||
|
|
||||||
## Dendrite 0.8.1 (2022-04-07)
|
## Dendrite 0.8.1 (2022-04-07)
|
||||||
|
|
||||||
### Fixes
|
### Fixes
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
FROM docker.io/golang:1.17-alpine AS base
|
FROM docker.io/golang:1.18-alpine AS base
|
||||||
|
|
||||||
RUN apk --update --no-cache add bash build-base
|
RUN apk --update --no-cache add bash build-base
|
||||||
|
|
||||||
|
|
@ -23,4 +23,4 @@ COPY --from=base /build/bin/* /usr/bin/
|
||||||
VOLUME /etc/dendrite
|
VOLUME /etc/dendrite
|
||||||
WORKDIR /etc/dendrite
|
WORKDIR /etc/dendrite
|
||||||
|
|
||||||
ENTRYPOINT ["/usr/bin/dendrite-monolith-server"]
|
ENTRYPOINT ["/usr/bin/dendrite-monolith-server"]
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
FROM docker.io/golang:1.17-alpine AS base
|
FROM docker.io/golang:1.18-alpine AS base
|
||||||
|
|
||||||
RUN apk --update --no-cache add bash build-base
|
RUN apk --update --no-cache add bash build-base
|
||||||
|
|
||||||
|
|
@ -23,4 +23,4 @@ COPY --from=base /build/bin/* /usr/bin/
|
||||||
VOLUME /etc/dendrite
|
VOLUME /etc/dendrite
|
||||||
WORKDIR /etc/dendrite
|
WORKDIR /etc/dendrite
|
||||||
|
|
||||||
ENTRYPOINT ["/usr/bin/dendrite-polylith-multi"]
|
ENTRYPOINT ["/usr/bin/dendrite-polylith-multi"]
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
|
"github.com/matrix-org/dendrite/internal/transactions"
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
@ -40,12 +41,21 @@ type redactionResponse struct {
|
||||||
func SendRedaction(
|
func SendRedaction(
|
||||||
req *http.Request, device *userapi.Device, roomID, eventID string, cfg *config.ClientAPI,
|
req *http.Request, device *userapi.Device, roomID, eventID string, cfg *config.ClientAPI,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||||
|
txnID *string,
|
||||||
|
txnCache *transactions.Cache,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
|
resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if txnID != nil {
|
||||||
|
// Try to fetch response from transactionsCache
|
||||||
|
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
|
||||||
|
return *res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ev := roomserverAPI.GetEvent(req.Context(), rsAPI, eventID)
|
ev := roomserverAPI.GetEvent(req.Context(), rsAPI, eventID)
|
||||||
if ev == nil {
|
if ev == nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
|
@ -124,10 +134,18 @@ func SendRedaction(
|
||||||
util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents")
|
util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
return util.JSONResponse{
|
|
||||||
|
res := util.JSONResponse{
|
||||||
Code: 200,
|
Code: 200,
|
||||||
JSON: redactionResponse{
|
JSON: redactionResponse{
|
||||||
EventID: e.EventID(),
|
EventID: e.EventID(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add response to transactionsCache
|
||||||
|
if txnID != nil {
|
||||||
|
txnCache.AddTransaction(device.AccessToken, *txnID, &res)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -479,7 +479,7 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI)
|
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, nil, nil)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
|
v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
|
||||||
|
|
@ -488,7 +488,8 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI)
|
txnID := vars["txnId"]
|
||||||
|
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, &txnID, transactionsCache)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -75,7 +75,7 @@ func (r *FederationInternalAPI) PerformJoin(
|
||||||
seenSet := make(map[gomatrixserverlib.ServerName]bool)
|
seenSet := make(map[gomatrixserverlib.ServerName]bool)
|
||||||
var uniqueList []gomatrixserverlib.ServerName
|
var uniqueList []gomatrixserverlib.ServerName
|
||||||
for _, srv := range request.ServerNames {
|
for _, srv := range request.ServerNames {
|
||||||
if seenSet[srv] {
|
if seenSet[srv] || srv == r.cfg.Matrix.ServerName {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
seenSet[srv] = true
|
seenSet[srv] = true
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
|
||||||
// this destination queue. We'll then be able to retrieve the PDU
|
// this destination queue. We'll then be able to retrieve the PDU
|
||||||
// later.
|
// later.
|
||||||
if err := oq.db.AssociatePDUWithDestination(
|
if err := oq.db.AssociatePDUWithDestination(
|
||||||
context.TODO(),
|
oq.process.Context(),
|
||||||
"", // TODO: remove this, as we don't need to persist the transaction ID
|
"", // TODO: remove this, as we don't need to persist the transaction ID
|
||||||
oq.destination, // the destination server name
|
oq.destination, // the destination server name
|
||||||
receipt, // NIDs from federationapi_queue_json table
|
receipt, // NIDs from federationapi_queue_json table
|
||||||
|
|
@ -122,7 +122,7 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
|
||||||
// this destination queue. We'll then be able to retrieve the PDU
|
// this destination queue. We'll then be able to retrieve the PDU
|
||||||
// later.
|
// later.
|
||||||
if err := oq.db.AssociateEDUWithDestination(
|
if err := oq.db.AssociateEDUWithDestination(
|
||||||
context.TODO(),
|
oq.process.Context(),
|
||||||
oq.destination, // the destination server name
|
oq.destination, // the destination server name
|
||||||
receipt, // NIDs from federationapi_queue_json table
|
receipt, // NIDs from federationapi_queue_json table
|
||||||
event.Type,
|
event.Type,
|
||||||
|
|
@ -177,7 +177,7 @@ func (oq *destinationQueue) getPendingFromDatabase() {
|
||||||
// Check to see if there's anything to do for this server
|
// Check to see if there's anything to do for this server
|
||||||
// in the database.
|
// in the database.
|
||||||
retrieved := false
|
retrieved := false
|
||||||
ctx := context.Background()
|
ctx := oq.process.Context()
|
||||||
oq.pendingMutex.Lock()
|
oq.pendingMutex.Lock()
|
||||||
defer oq.pendingMutex.Unlock()
|
defer oq.pendingMutex.Unlock()
|
||||||
|
|
||||||
|
|
@ -271,6 +271,9 @@ func (oq *destinationQueue) backgroundSend() {
|
||||||
// restarted automatically the next time we have an event to
|
// restarted automatically the next time we have an event to
|
||||||
// send.
|
// send.
|
||||||
return
|
return
|
||||||
|
case <-oq.process.Context().Done():
|
||||||
|
// The parent process is shutting down, so stop.
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we are backing off this server then wait for the
|
// If we are backing off this server then wait for the
|
||||||
|
|
@ -420,13 +423,13 @@ func (oq *destinationQueue) nextTransaction(
|
||||||
// Clean up the transaction in the database.
|
// Clean up the transaction in the database.
|
||||||
if pduReceipts != nil {
|
if pduReceipts != nil {
|
||||||
//logrus.Infof("Cleaning PDUs %q", pduReceipt.String())
|
//logrus.Infof("Cleaning PDUs %q", pduReceipt.String())
|
||||||
if err = oq.db.CleanPDUs(context.Background(), oq.destination, pduReceipts); err != nil {
|
if err = oq.db.CleanPDUs(oq.process.Context(), oq.destination, pduReceipts); err != nil {
|
||||||
logrus.WithError(err).Errorf("Failed to clean PDUs for server %q", t.Destination)
|
logrus.WithError(err).Errorf("Failed to clean PDUs for server %q", t.Destination)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if eduReceipts != nil {
|
if eduReceipts != nil {
|
||||||
//logrus.Infof("Cleaning EDUs %q", eduReceipt.String())
|
//logrus.Infof("Cleaning EDUs %q", eduReceipt.String())
|
||||||
if err = oq.db.CleanEDUs(context.Background(), oq.destination, eduReceipts); err != nil {
|
if err = oq.db.CleanEDUs(oq.process.Context(), oq.destination, eduReceipts); err != nil {
|
||||||
logrus.WithError(err).Errorf("Failed to clean EDUs for server %q", t.Destination)
|
logrus.WithError(err).Errorf("Failed to clean EDUs for server %q", t.Destination)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@
|
||||||
package queue
|
package queue
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
@ -105,14 +104,14 @@ func NewOutgoingQueues(
|
||||||
// Look up which servers we have pending items for and then rehydrate those queues.
|
// Look up which servers we have pending items for and then rehydrate those queues.
|
||||||
if !disabled {
|
if !disabled {
|
||||||
serverNames := map[gomatrixserverlib.ServerName]struct{}{}
|
serverNames := map[gomatrixserverlib.ServerName]struct{}{}
|
||||||
if names, err := db.GetPendingPDUServerNames(context.Background()); err == nil {
|
if names, err := db.GetPendingPDUServerNames(process.Context()); err == nil {
|
||||||
for _, serverName := range names {
|
for _, serverName := range names {
|
||||||
serverNames[serverName] = struct{}{}
|
serverNames[serverName] = struct{}{}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.WithError(err).Error("Failed to get PDU server names for destination queue hydration")
|
log.WithError(err).Error("Failed to get PDU server names for destination queue hydration")
|
||||||
}
|
}
|
||||||
if names, err := db.GetPendingEDUServerNames(context.Background()); err == nil {
|
if names, err := db.GetPendingEDUServerNames(process.Context()); err == nil {
|
||||||
for _, serverName := range names {
|
for _, serverName := range names {
|
||||||
serverNames[serverName] = struct{}{}
|
serverNames[serverName] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
@ -215,7 +214,7 @@ func (oqs *OutgoingQueues) SendEvent(
|
||||||
// Check if any of the destinations are prohibited by server ACLs.
|
// Check if any of the destinations are prohibited by server ACLs.
|
||||||
for destination := range destmap {
|
for destination := range destmap {
|
||||||
if api.IsServerBannedFromRoom(
|
if api.IsServerBannedFromRoom(
|
||||||
context.TODO(),
|
oqs.process.Context(),
|
||||||
oqs.rsAPI,
|
oqs.rsAPI,
|
||||||
ev.RoomID(),
|
ev.RoomID(),
|
||||||
destination,
|
destination,
|
||||||
|
|
@ -238,7 +237,7 @@ func (oqs *OutgoingQueues) SendEvent(
|
||||||
return fmt.Errorf("json.Marshal: %w", err)
|
return fmt.Errorf("json.Marshal: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
nid, err := oqs.db.StoreJSON(context.TODO(), string(headeredJSON))
|
nid, err := oqs.db.StoreJSON(oqs.process.Context(), string(headeredJSON))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
|
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -286,7 +285,7 @@ func (oqs *OutgoingQueues) SendEDU(
|
||||||
if result := gjson.GetBytes(e.Content, "room_id"); result.Exists() {
|
if result := gjson.GetBytes(e.Content, "room_id"); result.Exists() {
|
||||||
for destination := range destmap {
|
for destination := range destmap {
|
||||||
if api.IsServerBannedFromRoom(
|
if api.IsServerBannedFromRoom(
|
||||||
context.TODO(),
|
oqs.process.Context(),
|
||||||
oqs.rsAPI,
|
oqs.rsAPI,
|
||||||
result.Str,
|
result.Str,
|
||||||
destination,
|
destination,
|
||||||
|
|
@ -310,7 +309,7 @@ func (oqs *OutgoingQueues) SendEDU(
|
||||||
return fmt.Errorf("json.Marshal: %w", err)
|
return fmt.Errorf("json.Marshal: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
nid, err := oqs.db.StoreJSON(context.TODO(), string(ephemeralJSON))
|
nid, err := oqs.db.StoreJSON(oqs.process.Context(), string(ephemeralJSON))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
|
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
2
go.mod
2
go.mod
|
|
@ -25,6 +25,7 @@ require (
|
||||||
github.com/h2non/filetype v1.1.3 // indirect
|
github.com/h2non/filetype v1.1.3 // indirect
|
||||||
github.com/hashicorp/golang-lru v0.5.4
|
github.com/hashicorp/golang-lru v0.5.4
|
||||||
github.com/juju/testing v0.0.0-20220203020004-a0ff61f03494 // indirect
|
github.com/juju/testing v0.0.0-20220203020004-a0ff61f03494 // indirect
|
||||||
|
github.com/kardianos/minwinsvc v1.0.0 // indirect
|
||||||
github.com/lib/pq v1.10.5
|
github.com/lib/pq v1.10.5
|
||||||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
||||||
|
|
@ -46,6 +47,7 @@ require (
|
||||||
github.com/pressly/goose v2.7.0+incompatible
|
github.com/pressly/goose v2.7.0+incompatible
|
||||||
github.com/prometheus/client_golang v1.12.1
|
github.com/prometheus/client_golang v1.12.1
|
||||||
github.com/sirupsen/logrus v1.8.1
|
github.com/sirupsen/logrus v1.8.1
|
||||||
|
github.com/stretchr/testify v1.7.0
|
||||||
github.com/tidwall/gjson v1.14.0
|
github.com/tidwall/gjson v1.14.0
|
||||||
github.com/tidwall/sjson v1.2.4
|
github.com/tidwall/sjson v1.2.4
|
||||||
github.com/uber/jaeger-client-go v2.30.0+incompatible
|
github.com/uber/jaeger-client-go v2.30.0+incompatible
|
||||||
|
|
|
||||||
1
go.sum
1
go.sum
|
|
@ -721,6 +721,7 @@ github.com/julienschmidt/httprouter v1.1.1-0.20151013225520-77a895ad01eb/go.mod
|
||||||
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
||||||
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
||||||
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
||||||
|
github.com/kardianos/minwinsvc v1.0.0 h1:+JfAi8IBJna0jY2dJGZqi7o15z13JelFIklJCAENALA=
|
||||||
github.com/kardianos/minwinsvc v1.0.0/go.mod h1:Bgd0oc+D0Qo3bBytmNtyRKVlp85dAloLKhfxanPFFRc=
|
github.com/kardianos/minwinsvc v1.0.0/go.mod h1:Bgd0oc+D0Qo3bBytmNtyRKVlp85dAloLKhfxanPFFRc=
|
||||||
github.com/kataras/golog v0.0.10/go.mod h1:yJ8YKCmyL+nWjERB90Qwn+bdyBZsaQwU3bTVFgkFIp8=
|
github.com/kataras/golog v0.0.10/go.mod h1:yJ8YKCmyL+nWjERB90Qwn+bdyBZsaQwU3bTVFgkFIp8=
|
||||||
github.com/kataras/iris/v12 v12.1.8/go.mod h1:LMYy4VlP67TQ3Zgriz8RE2h2kMZV2SgMYbq3UhfoFmE=
|
github.com/kataras/iris/v12 v12.1.8/go.mod h1:LMYy4VlP67TQ3Zgriz8RE2h2kMZV2SgMYbq3UhfoFmE=
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ var build string
|
||||||
const (
|
const (
|
||||||
VersionMajor = 0
|
VersionMajor = 0
|
||||||
VersionMinor = 8
|
VersionMinor = 8
|
||||||
VersionPatch = 1
|
VersionPatch = 2
|
||||||
VersionTag = "" // example: "rc1"
|
VersionTag = "" // example: "rc1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -362,6 +362,13 @@ func (a *KeyInternalAPI) processSelfSignatures(
|
||||||
for targetKeyID, signature := range forTargetUserID {
|
for targetKeyID, signature := range forTargetUserID {
|
||||||
switch sig := signature.CrossSigningBody.(type) {
|
switch sig := signature.CrossSigningBody.(type) {
|
||||||
case *gomatrixserverlib.CrossSigningKey:
|
case *gomatrixserverlib.CrossSigningKey:
|
||||||
|
for keyID := range sig.Keys {
|
||||||
|
split := strings.SplitN(string(keyID), ":", 2)
|
||||||
|
if len(split) > 1 && gomatrixserverlib.KeyID(split[1]) == targetKeyID {
|
||||||
|
targetKeyID = keyID // contains the ed25519: or other scheme
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
for originUserID, forOriginUserID := range sig.Signatures {
|
for originUserID, forOriginUserID := range sig.Signatures {
|
||||||
for originKeyID, originSig := range forOriginUserID {
|
for originKeyID, originSig := range forOriginUserID {
|
||||||
if err := a.DB.StoreCrossSigningSigsForTarget(
|
if err := a.DB.StoreCrossSigningSigsForTarget(
|
||||||
|
|
|
||||||
|
|
@ -33,8 +33,10 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
|
||||||
target_user_id TEXT NOT NULL,
|
target_user_id TEXT NOT NULL,
|
||||||
target_key_id TEXT NOT NULL,
|
target_key_id TEXT NOT NULL,
|
||||||
signature TEXT NOT NULL,
|
signature TEXT NOT NULL,
|
||||||
PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
|
PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
|
||||||
`
|
`
|
||||||
|
|
||||||
const selectCrossSigningSigsForTargetSQL = "" +
|
const selectCrossSigningSigsForTargetSQL = "" +
|
||||||
|
|
@ -44,7 +46,7 @@ const selectCrossSigningSigsForTargetSQL = "" +
|
||||||
const upsertCrossSigningSigsForTargetSQL = "" +
|
const upsertCrossSigningSigsForTargetSQL = "" +
|
||||||
"INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
|
"INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
|
||||||
" VALUES($1, $2, $3, $4, $5)" +
|
" VALUES($1, $2, $3, $4, $5)" +
|
||||||
" ON CONFLICT (origin_user_id, target_user_id, target_key_id) DO UPDATE SET (origin_key_id, signature) = ($2, $5)"
|
" ON CONFLICT (origin_user_id, origin_key_id, target_user_id, target_key_id) DO UPDATE SET signature = $5"
|
||||||
|
|
||||||
const deleteCrossSigningSigsForTargetSQL = "" +
|
const deleteCrossSigningSigsForTargetSQL = "" +
|
||||||
"DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
|
"DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package deltas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoadFixCrossSigningSignatureIndexes(m *sqlutil.Migrations) {
|
||||||
|
m.AddMigration(UpFixCrossSigningSignatureIndexes, DownFixCrossSigningSignatureIndexes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec(`
|
||||||
|
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
|
||||||
|
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DownFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec(`
|
||||||
|
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
|
||||||
|
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id);
|
||||||
|
|
||||||
|
DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -54,6 +54,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrations()
|
m := sqlutil.NewMigrations()
|
||||||
deltas.LoadRefactorKeyChanges(m)
|
deltas.LoadRefactorKeyChanges(m)
|
||||||
|
deltas.LoadFixCrossSigningSignatureIndexes(m)
|
||||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -33,8 +33,10 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
|
||||||
target_user_id TEXT NOT NULL,
|
target_user_id TEXT NOT NULL,
|
||||||
target_key_id TEXT NOT NULL,
|
target_key_id TEXT NOT NULL,
|
||||||
signature TEXT NOT NULL,
|
signature TEXT NOT NULL,
|
||||||
PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
|
PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
|
||||||
`
|
`
|
||||||
|
|
||||||
const selectCrossSigningSigsForTargetSQL = "" +
|
const selectCrossSigningSigsForTargetSQL = "" +
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package deltas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoadFixCrossSigningSignatureIndexes(m *sqlutil.Migrations) {
|
||||||
|
m.AddMigration(UpFixCrossSigningSignatureIndexes, DownFixCrossSigningSignatureIndexes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec(`
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
|
||||||
|
origin_user_id TEXT NOT NULL,
|
||||||
|
origin_key_id TEXT NOT NULL,
|
||||||
|
target_user_id TEXT NOT NULL,
|
||||||
|
target_key_id TEXT NOT NULL,
|
||||||
|
signature TEXT NOT NULL,
|
||||||
|
PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
|
||||||
|
SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
|
||||||
|
|
||||||
|
DROP TABLE keyserver_cross_signing_sigs;
|
||||||
|
ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DownFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec(`
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
|
||||||
|
origin_user_id TEXT NOT NULL,
|
||||||
|
origin_key_id TEXT NOT NULL,
|
||||||
|
target_user_id TEXT NOT NULL,
|
||||||
|
target_key_id TEXT NOT NULL,
|
||||||
|
signature TEXT NOT NULL,
|
||||||
|
PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
|
||||||
|
SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
|
||||||
|
|
||||||
|
DROP TABLE keyserver_cross_signing_sigs;
|
||||||
|
ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
|
||||||
|
|
||||||
|
DELETE INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -53,6 +53,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
|
||||||
|
|
||||||
m := sqlutil.NewMigrations()
|
m := sqlutil.NewMigrations()
|
||||||
deltas.LoadRefactorKeyChanges(m)
|
deltas.LoadRefactorKeyChanges(m)
|
||||||
|
deltas.LoadFixCrossSigningSignatureIndexes(m)
|
||||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
@ -42,6 +43,7 @@ import (
|
||||||
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/kardianos/minwinsvc"
|
||||||
|
|
||||||
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||||
asinthttp "github.com/matrix-org/dendrite/appservice/inthttp"
|
asinthttp "github.com/matrix-org/dendrite/appservice/inthttp"
|
||||||
|
|
@ -55,8 +57,6 @@ import (
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp"
|
userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
_ "net/http/pprof"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// BaseDendrite is a base for creating new instances of dendrite. It parses
|
// BaseDendrite is a base for creating new instances of dendrite. It parses
|
||||||
|
|
@ -272,7 +272,7 @@ func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client {
|
||||||
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
||||||
// be called once per component.
|
// be called once per component.
|
||||||
func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
|
func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
|
||||||
db, err := userdb.NewDatabase(
|
db, err := userdb.NewUserAPIDatabase(
|
||||||
&b.Cfg.UserAPI.AccountDatabase,
|
&b.Cfg.UserAPI.AccountDatabase,
|
||||||
b.Cfg.Global.ServerName,
|
b.Cfg.Global.ServerName,
|
||||||
b.Cfg.UserAPI.BCryptCost,
|
b.Cfg.UserAPI.BCryptCost,
|
||||||
|
|
@ -346,6 +346,9 @@ func (b *BaseDendrite) SetupAndServeHTTP(
|
||||||
Addr: string(externalAddr),
|
Addr: string(externalAddr),
|
||||||
WriteTimeout: HTTPServerTimeout,
|
WriteTimeout: HTTPServerTimeout,
|
||||||
Handler: externalRouter,
|
Handler: externalRouter,
|
||||||
|
BaseContext: func(_ net.Listener) context.Context {
|
||||||
|
return b.ProcessContext.Context()
|
||||||
|
},
|
||||||
}
|
}
|
||||||
internalServ := externalServ
|
internalServ := externalServ
|
||||||
|
|
||||||
|
|
@ -361,6 +364,9 @@ func (b *BaseDendrite) SetupAndServeHTTP(
|
||||||
internalServ = &http.Server{
|
internalServ = &http.Server{
|
||||||
Addr: string(internalAddr),
|
Addr: string(internalAddr),
|
||||||
Handler: h2c.NewHandler(internalRouter, internalH2S),
|
Handler: h2c.NewHandler(internalRouter, internalH2S),
|
||||||
|
BaseContext: func(_ net.Listener) context.Context {
|
||||||
|
return b.ProcessContext.Context()
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -462,20 +468,22 @@ func (b *BaseDendrite) SetupAndServeHTTP(
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
minwinsvc.SetOnExit(b.ProcessContext.ShutdownDendrite)
|
||||||
<-b.ProcessContext.WaitForShutdown()
|
<-b.ProcessContext.WaitForShutdown()
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
logrus.Infof("Stopping HTTP listeners")
|
||||||
cancel()
|
_ = internalServ.Shutdown(context.Background())
|
||||||
|
_ = externalServ.Shutdown(context.Background())
|
||||||
_ = internalServ.Shutdown(ctx)
|
|
||||||
_ = externalServ.Shutdown(ctx)
|
|
||||||
logrus.Infof("Stopped HTTP listeners")
|
logrus.Infof("Stopped HTTP listeners")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BaseDendrite) WaitForShutdown() {
|
func (b *BaseDendrite) WaitForShutdown() {
|
||||||
sigs := make(chan os.Signal, 1)
|
sigs := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||||
<-sigs
|
select {
|
||||||
|
case <-sigs:
|
||||||
|
case <-b.ProcessContext.WaitForShutdown():
|
||||||
|
}
|
||||||
signal.Reset(syscall.SIGINT, syscall.SIGTERM)
|
signal.Reset(syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
logrus.Warnf("Shutdown signal received")
|
logrus.Warnf("Shutdown signal received")
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,16 @@ func JetStreamConsumer(
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
|
// If the parent context has given up then there's no point in
|
||||||
|
// carrying on doing anything, so stop the listener.
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
if err := sub.Unsubscribe(); err != nil {
|
||||||
|
logrus.WithContext(ctx).Warnf("Failed to unsubscribe %q", durable)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
// The context behaviour here is surprising — we supply a context
|
// The context behaviour here is surprising — we supply a context
|
||||||
// so that we can interrupt the fetch if we want, but NATS will still
|
// so that we can interrupt the fetch if we want, but NATS will still
|
||||||
// enforce its own deadline (roughly 5 seconds by default). Therefore
|
// enforce its own deadline (roughly 5 seconds by default). Therefore
|
||||||
|
|
@ -65,18 +75,18 @@ func JetStreamConsumer(
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
msg := msgs[0]
|
msg := msgs[0]
|
||||||
if err = msg.InProgress(); err != nil {
|
if err = msg.InProgress(nats.Context(ctx)); err != nil {
|
||||||
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err))
|
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err))
|
||||||
sentry.CaptureException(err)
|
sentry.CaptureException(err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if f(ctx, msg) {
|
if f(ctx, msg) {
|
||||||
if err = msg.AckSync(); err != nil {
|
if err = msg.AckSync(nats.Context(ctx)); err != nil {
|
||||||
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err))
|
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err))
|
||||||
sentry.CaptureException(err)
|
sentry.CaptureException(err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err = msg.Nak(); err != nil {
|
if err = msg.Nak(nats.Context(ctx)); err != nil {
|
||||||
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err))
|
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err))
|
||||||
sentry.CaptureException(err)
|
sentry.CaptureException(err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,7 @@ func Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient
|
||||||
StoreDir: string(cfg.StoragePath),
|
StoreDir: string(cfg.StoragePath),
|
||||||
NoSystemAccount: true,
|
NoSystemAccount: true,
|
||||||
MaxPayload: 16 * 1024 * 1024,
|
MaxPayload: 16 * 1024 * 1024,
|
||||||
|
NoSigs: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
|
|
||||||
|
|
@ -88,6 +88,11 @@ func (s *PresenceConsumer) Start() error {
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if presence == nil {
|
||||||
|
presence = &types.PresenceInternal{
|
||||||
|
UserID: userID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
deviceRes := api.QueryDevicesResponse{}
|
deviceRes := api.QueryDevicesResponse{}
|
||||||
if err = s.deviceAPI.QueryDevices(s.ctx, &api.QueryDevicesRequest{UserID: userID}, &deviceRes); err != nil {
|
if err = s.deviceAPI.QueryDevices(s.ctx, &api.QueryDevicesRequest{UserID: userID}, &deviceRes); err != nil {
|
||||||
|
|
@ -106,7 +111,9 @@ func (s *PresenceConsumer) Start() error {
|
||||||
|
|
||||||
m.Header.Set(jetstream.UserID, presence.UserID)
|
m.Header.Set(jetstream.UserID, presence.UserID)
|
||||||
m.Header.Set("presence", presence.ClientFields.Presence)
|
m.Header.Set("presence", presence.ClientFields.Presence)
|
||||||
m.Header.Set("status_msg", *presence.ClientFields.StatusMsg)
|
if presence.ClientFields.StatusMsg != nil {
|
||||||
|
m.Header.Set("status_msg", *presence.ClientFields.StatusMsg)
|
||||||
|
}
|
||||||
m.Header.Set("last_active_ts", strconv.Itoa(int(presence.LastActiveTS)))
|
m.Header.Set("last_active_ts", strconv.Itoa(int(presence.LastActiveTS)))
|
||||||
|
|
||||||
if err = msg.RespondMsg(m); err != nil {
|
if err = msg.RespondMsg(m); err != nil {
|
||||||
|
|
|
||||||
|
|
@ -44,8 +44,8 @@ func GetFilter(
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
filter, err := syncDB.GetFilter(req.Context(), localpart, filterID)
|
filter := gomatrixserverlib.DefaultFilter()
|
||||||
if err != nil {
|
if err := syncDB.GetFilter(req.Context(), &filter, localpart, filterID); err != nil {
|
||||||
//TODO better error handling. This error message is *probably* right,
|
//TODO better error handling. This error message is *probably* right,
|
||||||
// but if there are obscure db errors, this will also be returned,
|
// but if there are obscure db errors, this will also be returned,
|
||||||
// even though it is not correct.
|
// even though it is not correct.
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ type Database interface {
|
||||||
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
|
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
|
||||||
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
|
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
|
||||||
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
|
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
|
||||||
|
GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error)
|
||||||
|
|
||||||
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
|
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
|
||||||
|
|
||||||
|
|
@ -80,7 +81,7 @@ type Database interface {
|
||||||
// Returns a map following the format data[roomID] = []dataTypes
|
// Returns a map following the format data[roomID] = []dataTypes
|
||||||
// If no data is retrieved, returns an empty map
|
// If no data is retrieved, returns an empty map
|
||||||
// If there was an issue with the retrieval, returns an error
|
// If there was an issue with the retrieval, returns an error
|
||||||
GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, error)
|
GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, types.StreamPosition, error)
|
||||||
// UpsertAccountData keeps track of new or updated account data, by saving the type
|
// UpsertAccountData keeps track of new or updated account data, by saving the type
|
||||||
// of the new/updated data, and the user ID and room ID the data is related to (empty)
|
// of the new/updated data, and the user ID and room ID the data is related to (empty)
|
||||||
// room ID means the data isn't specific to any room)
|
// room ID means the data isn't specific to any room)
|
||||||
|
|
@ -124,10 +125,10 @@ type Database interface {
|
||||||
// CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified
|
// CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified
|
||||||
// from position, preventing the send-to-device table from growing indefinitely.
|
// from position, preventing the send-to-device table from growing indefinitely.
|
||||||
CleanSendToDeviceUpdates(ctx context.Context, userID, deviceID string, before types.StreamPosition) (err error)
|
CleanSendToDeviceUpdates(ctx context.Context, userID, deviceID string, before types.StreamPosition) (err error)
|
||||||
// GetFilter looks up the filter associated with a given local user and filter ID.
|
// GetFilter looks up the filter associated with a given local user and filter ID
|
||||||
// Returns a filter structure. Otherwise returns an error if no such filter exists
|
// and populates the target filter. Otherwise returns an error if no such filter exists
|
||||||
// or if there was an error talking to the database.
|
// or if there was an error talking to the database.
|
||||||
GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
|
GetFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error
|
||||||
// PutFilter puts the passed filter into the database.
|
// PutFilter puts the passed filter into the database.
|
||||||
// Returns the filterID as a string. Otherwise returns an error if something
|
// Returns the filterID as a string. Otherwise returns an error if something
|
||||||
// goes wrong.
|
// goes wrong.
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ const insertAccountDataSQL = "" +
|
||||||
" RETURNING id"
|
" RETURNING id"
|
||||||
|
|
||||||
const selectAccountDataInRangeSQL = "" +
|
const selectAccountDataInRangeSQL = "" +
|
||||||
"SELECT room_id, type FROM syncapi_account_data_type" +
|
"SELECT id, room_id, type FROM syncapi_account_data_type" +
|
||||||
" WHERE user_id = $1 AND id > $2 AND id <= $3" +
|
" WHERE user_id = $1 AND id > $2 AND id <= $3" +
|
||||||
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
|
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
|
||||||
" AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" +
|
" AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" +
|
||||||
|
|
@ -103,7 +103,7 @@ func (s *accountDataStatements) SelectAccountDataInRange(
|
||||||
userID string,
|
userID string,
|
||||||
r types.Range,
|
r types.Range,
|
||||||
accountDataEventFilter *gomatrixserverlib.EventFilter,
|
accountDataEventFilter *gomatrixserverlib.EventFilter,
|
||||||
) (data map[string][]string, err error) {
|
) (data map[string][]string, pos types.StreamPosition, err error) {
|
||||||
data = make(map[string][]string)
|
data = make(map[string][]string)
|
||||||
|
|
||||||
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(),
|
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(),
|
||||||
|
|
@ -116,11 +116,12 @@ func (s *accountDataStatements) SelectAccountDataInRange(
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
|
||||||
|
|
||||||
for rows.Next() {
|
var dataType string
|
||||||
var dataType string
|
var roomID string
|
||||||
var roomID string
|
var id types.StreamPosition
|
||||||
|
|
||||||
if err = rows.Scan(&roomID, &dataType); err != nil {
|
for rows.Next() {
|
||||||
|
if err = rows.Scan(&id, &roomID, &dataType); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -129,8 +130,14 @@ func (s *accountDataStatements) SelectAccountDataInRange(
|
||||||
} else {
|
} else {
|
||||||
data[roomID] = []string{dataType}
|
data[roomID] = []string{dataType}
|
||||||
}
|
}
|
||||||
|
if id > pos {
|
||||||
|
pos = id
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return data, rows.Err()
|
if pos == 0 {
|
||||||
|
pos = r.High()
|
||||||
|
}
|
||||||
|
return data, pos, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) SelectMaxAccountDataID(
|
func (s *accountDataStatements) SelectMaxAccountDataID(
|
||||||
|
|
|
||||||
|
|
@ -73,21 +73,20 @@ func NewPostgresFilterTable(db *sql.DB) (tables.Filter, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *filterStatements) SelectFilter(
|
func (s *filterStatements) SelectFilter(
|
||||||
ctx context.Context, localpart string, filterID string,
|
ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
|
||||||
) (*gomatrixserverlib.Filter, error) {
|
) error {
|
||||||
// Retrieve filter from database (stored as canonical JSON)
|
// Retrieve filter from database (stored as canonical JSON)
|
||||||
var filterData []byte
|
var filterData []byte
|
||||||
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
|
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unmarshal JSON into Filter struct
|
// Unmarshal JSON into Filter struct
|
||||||
filter := gomatrixserverlib.DefaultFilter()
|
if err = json.Unmarshal(filterData, &target); err != nil {
|
||||||
if err = json.Unmarshal(filterData, &filter); err != nil {
|
return err
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
return &filter, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *filterStatements) InsertFilter(
|
func (s *filterStatements) InsertFilter(
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,8 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
|
@ -61,9 +63,13 @@ const selectMembershipCountSQL = "" +
|
||||||
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
|
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
|
||||||
") t WHERE t.membership = $3"
|
") t WHERE t.membership = $3"
|
||||||
|
|
||||||
|
const selectHeroesSQL = "" +
|
||||||
|
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5"
|
||||||
|
|
||||||
type membershipsStatements struct {
|
type membershipsStatements struct {
|
||||||
upsertMembershipStmt *sql.Stmt
|
upsertMembershipStmt *sql.Stmt
|
||||||
selectMembershipCountStmt *sql.Stmt
|
selectMembershipCountStmt *sql.Stmt
|
||||||
|
selectHeroesStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
|
func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
|
||||||
|
|
@ -72,13 +78,11 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil {
|
return s, sqlutil.StatementList{
|
||||||
return nil, err
|
{&s.upsertMembershipStmt, upsertMembershipSQL},
|
||||||
}
|
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
|
||||||
if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil {
|
{&s.selectHeroesStmt, selectHeroesSQL},
|
||||||
return nil, err
|
}.Prepare(db)
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipsStatements) UpsertMembership(
|
func (s *membershipsStatements) UpsertMembership(
|
||||||
|
|
@ -108,3 +112,23 @@ func (s *membershipsStatements) SelectMembershipCount(
|
||||||
err = stmt.QueryRowContext(ctx, roomID, pos, membership).Scan(&count)
|
err = stmt.QueryRowContext(ctx, roomID, pos, membership).Scan(&count)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *membershipsStatements) SelectHeroes(
|
||||||
|
ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string,
|
||||||
|
) (heroes []string, err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectHeroesStmt)
|
||||||
|
var rows *sql.Rows
|
||||||
|
rows, err = stmt.QueryContext(ctx, roomID, userID, pq.StringArray(memberships))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed")
|
||||||
|
var hero string
|
||||||
|
for rows.Next() {
|
||||||
|
if err = rows.Scan(&hero); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
heroes = append(heroes, hero)
|
||||||
|
}
|
||||||
|
return heroes, rows.Err()
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -127,6 +127,9 @@ func (p *presenceStatements) GetPresenceForUser(
|
||||||
}
|
}
|
||||||
stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt)
|
stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt)
|
||||||
err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS)
|
err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
result.ClientFields.Presence = result.Presence.String()
|
result.ClientFields.Presence = result.Presence.String()
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -124,6 +124,10 @@ func (d *Database) MembershipCount(ctx context.Context, roomID, membership strin
|
||||||
return d.Memberships.SelectMembershipCount(ctx, nil, roomID, membership, pos)
|
return d.Memberships.SelectMembershipCount(ctx, nil, roomID, membership, pos)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) {
|
||||||
|
return d.Memberships.SelectHeroes(ctx, nil, roomID, userID, memberships)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
|
func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
|
||||||
return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
|
return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
|
||||||
}
|
}
|
||||||
|
|
@ -261,7 +265,7 @@ func (d *Database) DeletePeeks(
|
||||||
func (d *Database) GetAccountDataInRange(
|
func (d *Database) GetAccountDataInRange(
|
||||||
ctx context.Context, userID string, r types.Range,
|
ctx context.Context, userID string, r types.Range,
|
||||||
accountDataFilterPart *gomatrixserverlib.EventFilter,
|
accountDataFilterPart *gomatrixserverlib.EventFilter,
|
||||||
) (map[string][]string, error) {
|
) (map[string][]string, types.StreamPosition, error) {
|
||||||
return d.AccountData.SelectAccountDataInRange(ctx, userID, r, accountDataFilterPart)
|
return d.AccountData.SelectAccountDataInRange(ctx, userID, r, accountDataFilterPart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -509,9 +513,9 @@ func (d *Database) StreamToTopologicalPosition(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetFilter(
|
func (d *Database) GetFilter(
|
||||||
ctx context.Context, localpart string, filterID string,
|
ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
|
||||||
) (*gomatrixserverlib.Filter, error) {
|
) error {
|
||||||
return d.Filter.SelectFilter(ctx, localpart, filterID)
|
return d.Filter.SelectFilter(ctx, target, localpart, filterID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) PutFilter(
|
func (d *Database) PutFilter(
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ const insertAccountDataSQL = "" +
|
||||||
|
|
||||||
// further parameters are added by prepareWithFilters
|
// further parameters are added by prepareWithFilters
|
||||||
const selectAccountDataInRangeSQL = "" +
|
const selectAccountDataInRangeSQL = "" +
|
||||||
"SELECT room_id, type FROM syncapi_account_data_type" +
|
"SELECT id, room_id, type FROM syncapi_account_data_type" +
|
||||||
" WHERE user_id = $1 AND id > $2 AND id <= $3"
|
" WHERE user_id = $1 AND id > $2 AND id <= $3"
|
||||||
|
|
||||||
const selectMaxAccountDataIDSQL = "" +
|
const selectMaxAccountDataIDSQL = "" +
|
||||||
|
|
@ -95,7 +95,7 @@ func (s *accountDataStatements) SelectAccountDataInRange(
|
||||||
userID string,
|
userID string,
|
||||||
r types.Range,
|
r types.Range,
|
||||||
filter *gomatrixserverlib.EventFilter,
|
filter *gomatrixserverlib.EventFilter,
|
||||||
) (data map[string][]string, err error) {
|
) (data map[string][]string, pos types.StreamPosition, err error) {
|
||||||
data = make(map[string][]string)
|
data = make(map[string][]string)
|
||||||
stmt, params, err := prepareWithFilters(
|
stmt, params, err := prepareWithFilters(
|
||||||
s.db, nil, selectAccountDataInRangeSQL,
|
s.db, nil, selectAccountDataInRangeSQL,
|
||||||
|
|
@ -112,11 +112,12 @@ func (s *accountDataStatements) SelectAccountDataInRange(
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
|
||||||
|
|
||||||
for rows.Next() {
|
var dataType string
|
||||||
var dataType string
|
var roomID string
|
||||||
var roomID string
|
var id types.StreamPosition
|
||||||
|
|
||||||
if err = rows.Scan(&roomID, &dataType); err != nil {
|
for rows.Next() {
|
||||||
|
if err = rows.Scan(&id, &roomID, &dataType); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -125,9 +126,14 @@ func (s *accountDataStatements) SelectAccountDataInRange(
|
||||||
} else {
|
} else {
|
||||||
data[roomID] = []string{dataType}
|
data[roomID] = []string{dataType}
|
||||||
}
|
}
|
||||||
|
if id > pos {
|
||||||
|
pos = id
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if pos == 0 {
|
||||||
return data, nil
|
pos = r.High()
|
||||||
|
}
|
||||||
|
return data, pos, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) SelectMaxAccountDataID(
|
func (s *accountDataStatements) SelectMaxAccountDataID(
|
||||||
|
|
|
||||||
|
|
@ -77,21 +77,20 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *filterStatements) SelectFilter(
|
func (s *filterStatements) SelectFilter(
|
||||||
ctx context.Context, localpart string, filterID string,
|
ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
|
||||||
) (*gomatrixserverlib.Filter, error) {
|
) error {
|
||||||
// Retrieve filter from database (stored as canonical JSON)
|
// Retrieve filter from database (stored as canonical JSON)
|
||||||
var filterData []byte
|
var filterData []byte
|
||||||
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
|
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unmarshal JSON into Filter struct
|
// Unmarshal JSON into Filter struct
|
||||||
filter := gomatrixserverlib.DefaultFilter()
|
if err = json.Unmarshal(filterData, &target); err != nil {
|
||||||
if err = json.Unmarshal(filterData, &filter); err != nil {
|
return err
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
return &filter, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *filterStatements) InsertFilter(
|
func (s *filterStatements) InsertFilter(
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,9 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
|
@ -61,10 +63,14 @@ const selectMembershipCountSQL = "" +
|
||||||
" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +
|
" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +
|
||||||
") t WHERE t.membership = $3"
|
") t WHERE t.membership = $3"
|
||||||
|
|
||||||
|
const selectHeroesSQL = "" +
|
||||||
|
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5"
|
||||||
|
|
||||||
type membershipsStatements struct {
|
type membershipsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
upsertMembershipStmt *sql.Stmt
|
upsertMembershipStmt *sql.Stmt
|
||||||
selectMembershipCountStmt *sql.Stmt
|
selectMembershipCountStmt *sql.Stmt
|
||||||
|
//selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
|
func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
|
||||||
|
|
@ -75,13 +81,11 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil {
|
return s, sqlutil.StatementList{
|
||||||
return nil, err
|
{&s.upsertMembershipStmt, upsertMembershipSQL},
|
||||||
}
|
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
|
||||||
if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil {
|
// {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic
|
||||||
return nil, err
|
}.Prepare(db)
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipsStatements) UpsertMembership(
|
func (s *membershipsStatements) UpsertMembership(
|
||||||
|
|
@ -111,3 +115,36 @@ func (s *membershipsStatements) SelectMembershipCount(
|
||||||
err = stmt.QueryRowContext(ctx, roomID, pos, membership).Scan(&count)
|
err = stmt.QueryRowContext(ctx, roomID, pos, membership).Scan(&count)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *membershipsStatements) SelectHeroes(
|
||||||
|
ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string,
|
||||||
|
) (heroes []string, err error) {
|
||||||
|
stmtSQL := strings.Replace(selectHeroesSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
|
||||||
|
stmt, err := s.db.PrepareContext(ctx, stmtSQL)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, stmt, "SelectHeroes: stmt.close() failed")
|
||||||
|
params := []interface{}{
|
||||||
|
roomID, userID,
|
||||||
|
}
|
||||||
|
for _, membership := range memberships {
|
||||||
|
params = append(params, membership)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
|
var rows *sql.Rows
|
||||||
|
rows, err = stmt.QueryContext(ctx, params...)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed")
|
||||||
|
var hero string
|
||||||
|
for rows.Next() {
|
||||||
|
if err = rows.Scan(&hero); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
heroes = append(heroes, hero)
|
||||||
|
}
|
||||||
|
return heroes, rows.Err()
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -142,6 +142,9 @@ func (p *presenceStatements) GetPresenceForUser(
|
||||||
}
|
}
|
||||||
stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt)
|
stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt)
|
||||||
err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS)
|
err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
result.ClientFields.Presence = result.Presence.String()
|
result.ClientFields.Presence = result.Presence.String()
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ import (
|
||||||
type AccountData interface {
|
type AccountData interface {
|
||||||
InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error)
|
InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error)
|
||||||
// SelectAccountDataInRange returns a map of room ID to a list of `dataType`.
|
// SelectAccountDataInRange returns a map of room ID to a list of `dataType`.
|
||||||
SelectAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, err error)
|
SelectAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, pos types.StreamPosition, err error)
|
||||||
SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -157,7 +157,7 @@ type SendToDevice interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Filter interface {
|
type Filter interface {
|
||||||
SelectFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
|
SelectFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error
|
||||||
InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error)
|
InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -170,6 +170,7 @@ type Receipts interface {
|
||||||
type Memberships interface {
|
type Memberships interface {
|
||||||
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
|
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
|
||||||
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
|
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
|
||||||
|
SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type NotificationData interface {
|
type NotificationData interface {
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ func (p *AccountDataStreamProvider) IncrementalSync(
|
||||||
To: to,
|
To: to,
|
||||||
}
|
}
|
||||||
|
|
||||||
dataTypes, err := p.DB.GetAccountDataInRange(
|
dataTypes, pos, err := p.DB.GetAccountDataInRange(
|
||||||
ctx, req.Device.UserID, r, &req.Filter.AccountData,
|
ctx, req.Device.UserID, r, &req.Filter.AccountData,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -53,6 +53,12 @@ func (p *AccountDataStreamProvider) IncrementalSync(
|
||||||
|
|
||||||
// Iterate over the rooms
|
// Iterate over the rooms
|
||||||
for roomID, dataTypes := range dataTypes {
|
for roomID, dataTypes := range dataTypes {
|
||||||
|
// For a complete sync, make sure we're only including this room if
|
||||||
|
// that room was present in the joined rooms.
|
||||||
|
if from == 0 && roomID != "" && !req.IsRoomPresent(roomID) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Request the missing data from the database
|
// Request the missing data from the database
|
||||||
for _, dataType := range dataTypes {
|
for _, dataType := range dataTypes {
|
||||||
dataReq := userapi.QueryAccountDataRequest{
|
dataReq := userapi.QueryAccountDataRequest{
|
||||||
|
|
@ -95,5 +101,5 @@ func (p *AccountDataStreamProvider) IncrementalSync(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return to
|
return pos
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,13 +4,16 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -30,6 +33,7 @@ type PDUStreamProvider struct {
|
||||||
workers atomic.Int32
|
workers atomic.Int32
|
||||||
// userID+deviceID -> lazy loading cache
|
// userID+deviceID -> lazy loading cache
|
||||||
lazyLoadCache *caching.LazyLoadCache
|
lazyLoadCache *caching.LazyLoadCache
|
||||||
|
rsAPI roomserverAPI.RoomserverInternalAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PDUStreamProvider) worker() {
|
func (p *PDUStreamProvider) worker() {
|
||||||
|
|
@ -290,16 +294,11 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Work out how many members are in the room.
|
|
||||||
joinedCount, _ := p.DB.MembershipCount(ctx, delta.RoomID, gomatrixserverlib.Join, latestPosition)
|
|
||||||
invitedCount, _ := p.DB.MembershipCount(ctx, delta.RoomID, gomatrixserverlib.Invite, latestPosition)
|
|
||||||
|
|
||||||
switch delta.Membership {
|
switch delta.Membership {
|
||||||
case gomatrixserverlib.Join:
|
case gomatrixserverlib.Join:
|
||||||
jr := types.NewJoinResponse()
|
jr := types.NewJoinResponse()
|
||||||
if hasMembershipChange {
|
if hasMembershipChange {
|
||||||
jr.Summary.JoinedMemberCount = &joinedCount
|
p.addRoomSummary(ctx, jr, delta.RoomID, device.UserID, latestPosition)
|
||||||
jr.Summary.InvitedMemberCount = &invitedCount
|
|
||||||
}
|
}
|
||||||
jr.Timeline.PrevBatch = &prevBatch
|
jr.Timeline.PrevBatch = &prevBatch
|
||||||
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
|
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
|
||||||
|
|
@ -332,6 +331,45 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
return latestPosition, nil
|
return latestPosition, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
|
||||||
|
// Work out how many members are in the room.
|
||||||
|
joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
|
||||||
|
invitedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition)
|
||||||
|
|
||||||
|
jr.Summary.JoinedMemberCount = &joinedCount
|
||||||
|
jr.Summary.InvitedMemberCount = &invitedCount
|
||||||
|
|
||||||
|
fetchStates := []gomatrixserverlib.StateKeyTuple{
|
||||||
|
{EventType: gomatrixserverlib.MRoomName},
|
||||||
|
{EventType: gomatrixserverlib.MRoomCanonicalAlias},
|
||||||
|
}
|
||||||
|
// Check if the room has a name or a canonical alias
|
||||||
|
latestState := &roomserverAPI.QueryLatestEventsAndStateResponse{}
|
||||||
|
err := p.rsAPI.QueryLatestEventsAndState(ctx, &roomserverAPI.QueryLatestEventsAndStateRequest{StateToFetch: fetchStates, RoomID: roomID}, latestState)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Check if the room has a name or canonical alias, if so, return.
|
||||||
|
for _, ev := range latestState.StateEvents {
|
||||||
|
switch ev.Type() {
|
||||||
|
case gomatrixserverlib.MRoomName:
|
||||||
|
if gjson.GetBytes(ev.Content(), "name").Str != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case gomatrixserverlib.MRoomCanonicalAlias:
|
||||||
|
if gjson.GetBytes(ev.Content(), "alias").Str != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
heroes, err := p.DB.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sort.Strings(heroes)
|
||||||
|
jr.Summary.Heroes = heroes
|
||||||
|
}
|
||||||
|
|
||||||
func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
roomID string,
|
roomID string,
|
||||||
|
|
@ -416,9 +454,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||||
prevBatch.Decrement()
|
prevBatch.Decrement()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Work out how many members are in the room.
|
p.addRoomSummary(ctx, jr, roomID, device.UserID, r.From)
|
||||||
joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, r.From)
|
|
||||||
invitedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, r.From)
|
|
||||||
|
|
||||||
// We don't include a device here as we don't need to send down
|
// We don't include a device here as we don't need to send down
|
||||||
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for:
|
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for:
|
||||||
|
|
@ -439,8 +475,6 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
jr.Summary.JoinedMemberCount = &joinedCount
|
|
||||||
jr.Summary.InvitedMemberCount = &invitedCount
|
|
||||||
jr.Timeline.PrevBatch = prevBatch
|
jr.Timeline.PrevBatch = prevBatch
|
||||||
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
|
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
|
||||||
jr.Timeline.Limited = limited
|
jr.Timeline.Limited = limited
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@ package streams
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
|
@ -80,11 +79,10 @@ func (p *PresenceStreamProvider) IncrementalSync(
|
||||||
if _, ok := presences[roomUsers[i]]; ok {
|
if _, ok := presences[roomUsers[i]]; ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// Bear in mind that this might return nil, but at least populating
|
||||||
|
// a nil means that there's a map entry so we won't repeat this call.
|
||||||
presences[roomUsers[i]], err = p.DB.GetPresence(ctx, roomUsers[i])
|
presences[roomUsers[i]], err = p.DB.GetPresence(ctx, roomUsers[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
req.Log.WithError(err).Error("unable to query presence for user")
|
req.Log.WithError(err).Error("unable to query presence for user")
|
||||||
return from
|
return from
|
||||||
}
|
}
|
||||||
|
|
@ -93,8 +91,10 @@ func (p *PresenceStreamProvider) IncrementalSync(
|
||||||
}
|
}
|
||||||
|
|
||||||
lastPos := to
|
lastPos := to
|
||||||
for i := range presences {
|
for _, presence := range presences {
|
||||||
presence := presences[i]
|
if presence == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
// Ignore users we don't share a room with
|
// Ignore users we don't share a room with
|
||||||
if req.Device.UserID != presence.UserID && !p.notifier.IsSharedUser(req.Device.UserID, presence.UserID) {
|
if req.Device.UserID != presence.UserID && !p.notifier.IsSharedUser(req.Device.UserID, presence.UserID) {
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,12 @@ func (p *ReceiptStreamProvider) IncrementalSync(
|
||||||
}
|
}
|
||||||
|
|
||||||
for roomID, receipts := range receiptsByRoom {
|
for roomID, receipts := range receiptsByRoom {
|
||||||
|
// For a complete sync, make sure we're only including this room if
|
||||||
|
// that room was present in the joined rooms.
|
||||||
|
if from == 0 && !req.IsRoomPresent(roomID) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
jr := *types.NewJoinResponse()
|
jr := *types.NewJoinResponse()
|
||||||
if existing, ok := req.Response.Rooms.Join[roomID]; ok {
|
if existing, ok := req.Response.Rooms.Join[roomID]; ok {
|
||||||
jr = existing
|
jr = existing
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ func NewSyncStreamProviders(
|
||||||
PDUStreamProvider: &PDUStreamProvider{
|
PDUStreamProvider: &PDUStreamProvider{
|
||||||
StreamProvider: StreamProvider{DB: d},
|
StreamProvider: StreamProvider{DB: d},
|
||||||
lazyLoadCache: lazyLoadCache,
|
lazyLoadCache: lazyLoadCache,
|
||||||
|
rsAPI: rsAPI,
|
||||||
},
|
},
|
||||||
TypingStreamProvider: &TypingStreamProvider{
|
TypingStreamProvider: &TypingStreamProvider{
|
||||||
StreamProvider: StreamProvider{DB: d},
|
StreamProvider: StreamProvider{DB: d},
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -47,6 +48,13 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
|
||||||
}
|
}
|
||||||
// TODO: read from stored filters too
|
// TODO: read from stored filters too
|
||||||
filter := gomatrixserverlib.DefaultFilter()
|
filter := gomatrixserverlib.DefaultFilter()
|
||||||
|
if since.IsEmpty() {
|
||||||
|
// Send as much account data down for complete syncs as possible
|
||||||
|
// by default, otherwise clients do weird things while waiting
|
||||||
|
// for the rest of the data to trickle down.
|
||||||
|
filter.AccountData.Limit = math.MaxInt32
|
||||||
|
filter.Room.AccountData.Limit = math.MaxInt32
|
||||||
|
}
|
||||||
filterQuery := req.URL.Query().Get("filter")
|
filterQuery := req.URL.Query().Get("filter")
|
||||||
if filterQuery != "" {
|
if filterQuery != "" {
|
||||||
if filterQuery[0] == '{' {
|
if filterQuery[0] == '{' {
|
||||||
|
|
@ -61,11 +69,9 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||||
return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||||
}
|
}
|
||||||
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil && err != sql.ErrNoRows {
|
if err := syncDB.GetFilter(req.Context(), &filter, localpart, filterQuery); err != nil && err != sql.ErrNoRows {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed")
|
util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed")
|
||||||
return nil, fmt.Errorf("syncDB.GetFilter: %w", err)
|
return nil, fmt.Errorf("syncDB.GetFilter: %w", err)
|
||||||
} else if f != nil {
|
|
||||||
filter = *f
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -127,14 +127,23 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user
|
||||||
if !ok { // this should almost never happen
|
if !ok { // this should almost never happen
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newPresence := types.PresenceInternal{
|
newPresence := types.PresenceInternal{
|
||||||
ClientFields: types.PresenceClientResponse{
|
|
||||||
Presence: presenceID.String(),
|
|
||||||
},
|
|
||||||
Presence: presenceID,
|
Presence: presenceID,
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
LastActiveTS: gomatrixserverlib.AsTimestamp(time.Now()),
|
LastActiveTS: gomatrixserverlib.AsTimestamp(time.Now()),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensure we also send the current status_msg to federated servers and not nil
|
||||||
|
dbPresence, err := db.GetPresence(context.Background(), userID)
|
||||||
|
if err != nil && err != sql.ErrNoRows {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if dbPresence != nil {
|
||||||
|
newPresence.ClientFields = dbPresence.ClientFields
|
||||||
|
}
|
||||||
|
newPresence.ClientFields.Presence = presenceID.String()
|
||||||
|
|
||||||
defer rp.presence.Store(userID, newPresence)
|
defer rp.presence.Store(userID, newPresence)
|
||||||
// avoid spamming presence updates when syncing
|
// avoid spamming presence updates when syncing
|
||||||
existingPresence, ok := rp.presence.LoadOrStore(userID, newPresence)
|
existingPresence, ok := rp.presence.LoadOrStore(userID, newPresence)
|
||||||
|
|
@ -145,13 +154,7 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensure we also send the current status_msg to federated servers and not nil
|
if err := rp.producer.SendPresence(userID, presenceID, newPresence.ClientFields.StatusMsg); err != nil {
|
||||||
dbPresence, err := db.GetPresence(context.Background(), userID)
|
|
||||||
if err != nil && err != sql.ErrNoRows {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rp.producer.SendPresence(userID, presenceID, dbPresence.ClientFields.StatusMsg); err != nil {
|
|
||||||
logrus.WithError(err).Error("Unable to publish presence message from sync")
|
logrus.WithError(err).Error("Unable to publish presence message from sync")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,23 @@ type SyncRequest struct {
|
||||||
IgnoredUsers IgnoredUsers
|
IgnoredUsers IgnoredUsers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *SyncRequest) IsRoomPresent(roomID string) bool {
|
||||||
|
membership, ok := r.Rooms[roomID]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch membership {
|
||||||
|
case gomatrixserverlib.Join:
|
||||||
|
return true
|
||||||
|
case gomatrixserverlib.Invite:
|
||||||
|
return true
|
||||||
|
case gomatrixserverlib.Peek:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type StreamProvider interface {
|
type StreamProvider interface {
|
||||||
Setup()
|
Setup()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -681,8 +681,6 @@ GET /presence/:user_id/status fetches initial status
|
||||||
PUT /presence/:user_id/status updates my presence
|
PUT /presence/:user_id/status updates my presence
|
||||||
Presence change reports an event to myself
|
Presence change reports an event to myself
|
||||||
Existing members see new members' presence
|
Existing members see new members' presence
|
||||||
#Existing members see new member's presence
|
|
||||||
Newly joined room includes presence in incremental sync
|
|
||||||
Get presence for newly joined members in incremental sync
|
Get presence for newly joined members in incremental sync
|
||||||
User sees their own presence in a sync
|
User sees their own presence in a sync
|
||||||
User sees updates to presence from other users in the incremental sync.
|
User sees updates to presence from other users in the incremental sync.
|
||||||
|
|
@ -713,4 +711,8 @@ Presence can be set from sync
|
||||||
/state returns M_NOT_FOUND for a rejected message event
|
/state returns M_NOT_FOUND for a rejected message event
|
||||||
/state_ids returns M_NOT_FOUND for a rejected message event
|
/state_ids returns M_NOT_FOUND for a rejected message event
|
||||||
/state returns M_NOT_FOUND for a rejected state event
|
/state returns M_NOT_FOUND for a rejected state event
|
||||||
/state_ids returns M_NOT_FOUND for a rejected state event
|
/state_ids returns M_NOT_FOUND for a rejected state event
|
||||||
|
PUT /rooms/:room_id/redact/:event_id/:txn_id is idempotent
|
||||||
|
Unnamed room comes with a name summary
|
||||||
|
Named room comes with just joined member count summary
|
||||||
|
Room summary only has 5 heroes
|
||||||
|
|
@ -27,18 +27,24 @@ import (
|
||||||
type Profile interface {
|
type Profile interface {
|
||||||
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
||||||
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||||
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
|
|
||||||
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
|
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
|
||||||
SetDisplayName(ctx context.Context, localpart string, displayName string) error
|
SetDisplayName(ctx context.Context, localpart string, displayName string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Database interface {
|
type Account interface {
|
||||||
Profile
|
|
||||||
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
|
||||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||||
// account already exists, it will return nil, ErrUserExists.
|
// account already exists, it will return nil, ErrUserExists.
|
||||||
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||||
|
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
||||||
|
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
||||||
|
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
||||||
|
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||||
|
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||||
|
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type AccountData interface {
|
||||||
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
||||||
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||||
// GetAccountDataByType returns account data matching a given
|
// GetAccountDataByType returns account data matching a given
|
||||||
|
|
@ -46,26 +52,9 @@ type Database interface {
|
||||||
// If no account data could be found, returns nil
|
// If no account data could be found, returns nil
|
||||||
// Returns an error if there was an issue with the retrieval
|
// Returns an error if there was an issue with the retrieval
|
||||||
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
||||||
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
}
|
||||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
|
||||||
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
|
||||||
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
|
|
||||||
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
|
||||||
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
|
||||||
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
|
||||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
|
||||||
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
|
|
||||||
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
|
||||||
|
|
||||||
// Key backups
|
|
||||||
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
|
||||||
UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error)
|
|
||||||
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
|
|
||||||
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
|
|
||||||
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
|
|
||||||
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
|
|
||||||
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
|
|
||||||
|
|
||||||
|
type Device interface {
|
||||||
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
||||||
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
||||||
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
||||||
|
|
@ -79,11 +68,22 @@ type Database interface {
|
||||||
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
||||||
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
||||||
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
|
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
|
||||||
RemoveDevice(ctx context.Context, deviceID, localpart string) error
|
|
||||||
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
||||||
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
||||||
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyBackup interface {
|
||||||
|
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
||||||
|
UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error)
|
||||||
|
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
|
||||||
|
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
|
||||||
|
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
|
||||||
|
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
|
||||||
|
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type LoginToken interface {
|
||||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||||
// determined by the loginTokenLifetime given to the Database constructor.
|
// determined by the loginTokenLifetime given to the Database constructor.
|
||||||
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
|
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
|
||||||
|
|
@ -94,21 +94,50 @@ type Database interface {
|
||||||
// GetLoginTokenDataByToken returns the data associated with the given token.
|
// GetLoginTokenDataByToken returns the data associated with the given token.
|
||||||
// May return sql.ErrNoRows.
|
// May return sql.ErrNoRows.
|
||||||
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
|
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
|
||||||
|
}
|
||||||
|
|
||||||
InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error
|
type OpenID interface {
|
||||||
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error)
|
CreateOpenIDToken(ctx context.Context, token, userID string) (exp int64, err error)
|
||||||
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error)
|
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||||
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
|
}
|
||||||
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
|
|
||||||
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
|
|
||||||
DeleteOldNotifications(ctx context.Context) error
|
|
||||||
|
|
||||||
|
type Pusher interface {
|
||||||
UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
|
UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
|
||||||
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
|
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
|
||||||
RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
|
RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
|
||||||
RemovePushers(ctx context.Context, appid, pushkey string) error
|
RemovePushers(ctx context.Context, appid, pushkey string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ThreePID interface {
|
||||||
|
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
||||||
|
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||||
|
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
|
||||||
|
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Notification interface {
|
||||||
|
InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error
|
||||||
|
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error)
|
||||||
|
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, read bool) (affected bool, err error)
|
||||||
|
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
|
||||||
|
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
|
||||||
|
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
|
||||||
|
DeleteOldNotifications(ctx context.Context) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Database interface {
|
||||||
|
Account
|
||||||
|
AccountData
|
||||||
|
Device
|
||||||
|
KeyBackup
|
||||||
|
LoginToken
|
||||||
|
Notification
|
||||||
|
OpenID
|
||||||
|
Profile
|
||||||
|
Pusher
|
||||||
|
ThreePID
|
||||||
|
}
|
||||||
|
|
||||||
// Err3PIDInUse is the error returned when trying to save an association involving
|
// Err3PIDInUse is the error returned when trying to save an association involving
|
||||||
// a third-party identifier which is already associated to a local user.
|
// a third-party identifier which is already associated to a local user.
|
||||||
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
|
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
|
||||||
|
|
|
||||||
|
|
@ -47,8 +47,6 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
-- TODO:
|
-- TODO:
|
||||||
-- upgraded_ts, devices, any email reset stuff?
|
-- upgraded_ts, devices, any email reset stuff?
|
||||||
);
|
);
|
||||||
-- Create sequence for autogenerated numeric usernames
|
|
||||||
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
|
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertAccountSQL = "" +
|
const insertAccountSQL = "" +
|
||||||
|
|
@ -67,7 +65,7 @@ const selectPasswordHashSQL = "" +
|
||||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
||||||
|
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
"SELECT nextval('numeric_username_seq')"
|
"SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
|
|
@ -178,5 +176,5 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||||
stmt = sqlutil.TxStmt(txn, stmt)
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
}
|
}
|
||||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||||
return
|
return id + 1, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ const selectDeviceByIDSQL = "" +
|
||||||
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||||
|
|
||||||
const selectDevicesByLocalpartSQL = "" +
|
const selectDevicesByLocalpartSQL = "" +
|
||||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
||||||
|
|
||||||
const updateDeviceNameSQL = "" +
|
const updateDeviceNameSQL = "" +
|
||||||
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||||
|
|
@ -93,7 +93,7 @@ const deleteDevicesSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
||||||
|
|
||||||
const selectDevicesByIDSQL = "" +
|
const selectDevicesByIDSQL = "" +
|
||||||
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)"
|
"SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
|
||||||
|
|
||||||
const updateDeviceLastSeen = "" +
|
const updateDeviceLastSeen = "" +
|
||||||
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
|
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
|
||||||
|
|
@ -235,16 +235,20 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
|
||||||
var devices []api.Device
|
var devices []api.Device
|
||||||
|
var dev api.Device
|
||||||
|
var localpart string
|
||||||
|
var lastseents sql.NullInt64
|
||||||
|
var displayName sql.NullString
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var dev api.Device
|
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
|
||||||
var localpart string
|
|
||||||
var displayName sql.NullString
|
|
||||||
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
dev.DisplayName = displayName.String
|
dev.DisplayName = displayName.String
|
||||||
}
|
}
|
||||||
|
if lastseents.Valid {
|
||||||
|
dev.LastSeenTS = lastseents.Int64
|
||||||
|
}
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
devices = append(devices, dev)
|
devices = append(devices, dev)
|
||||||
}
|
}
|
||||||
|
|
@ -262,10 +266,10 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed")
|
||||||
|
|
||||||
|
var dev api.Device
|
||||||
|
var lastseents sql.NullInt64
|
||||||
|
var id, displayname, ip, useragent sql.NullString
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var dev api.Device
|
|
||||||
var lastseents sql.NullInt64
|
|
||||||
var id, displayname, ip, useragent sql.NullString
|
|
||||||
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
|
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return devices, err
|
return devices, err
|
||||||
|
|
|
||||||
|
|
@ -577,21 +577,6 @@ func (d *Database) UpdateDevice(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveDevice revokes a device by deleting the entry in the database
|
|
||||||
// matching with the given device ID and user ID localpart.
|
|
||||||
// If the device doesn't exist, it will not return an error
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveDevice(
|
|
||||||
ctx context.Context, deviceID, localpart string,
|
|
||||||
) error {
|
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
|
||||||
if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
||||||
// matching with the given device IDs and user ID localpart.
|
// matching with the given device IDs and user ID localpart.
|
||||||
// If the devices don't exist, it will not return an error
|
// If the devices don't exist, it will not return an error
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ const selectPasswordHashSQL = "" +
|
||||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||||
|
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
"SELECT COUNT(localpart) FROM account_accounts"
|
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
|
@ -121,6 +121,7 @@ func (s *accountsStatements) InsertAccount(
|
||||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||||
ServerName: s.serverName,
|
ServerName: s.serverName,
|
||||||
AppServiceID: appserviceID,
|
AppServiceID: appserviceID,
|
||||||
|
AccountType: accountType,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -177,5 +178,8 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||||
stmt = sqlutil.TxStmt(txn, stmt)
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
}
|
}
|
||||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||||
return
|
if err == sql.ErrNoRows {
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
return id + 1, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,7 @@ const selectDeviceByIDSQL = "" +
|
||||||
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||||
|
|
||||||
const selectDevicesByLocalpartSQL = "" +
|
const selectDevicesByLocalpartSQL = "" +
|
||||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
||||||
|
|
||||||
const updateDeviceNameSQL = "" +
|
const updateDeviceNameSQL = "" +
|
||||||
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||||
|
|
@ -78,7 +78,7 @@ const deleteDevicesSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||||
|
|
||||||
const selectDevicesByIDSQL = "" +
|
const selectDevicesByIDSQL = "" +
|
||||||
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
|
"SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
||||||
|
|
||||||
const updateDeviceLastSeen = "" +
|
const updateDeviceLastSeen = "" +
|
||||||
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
|
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
|
||||||
|
|
@ -235,10 +235,10 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||||
return devices, err
|
return devices, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var dev api.Device
|
||||||
|
var lastseents sql.NullInt64
|
||||||
|
var id, displayname, ip, useragent sql.NullString
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var dev api.Device
|
|
||||||
var lastseents sql.NullInt64
|
|
||||||
var id, displayname, ip, useragent sql.NullString
|
|
||||||
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
|
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return devices, err
|
return devices, err
|
||||||
|
|
@ -279,16 +279,20 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
|
||||||
var devices []api.Device
|
var devices []api.Device
|
||||||
|
var dev api.Device
|
||||||
|
var localpart string
|
||||||
|
var displayName sql.NullString
|
||||||
|
var lastseents sql.NullInt64
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var dev api.Device
|
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
|
||||||
var localpart string
|
|
||||||
var displayName sql.NullString
|
|
||||||
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
dev.DisplayName = displayName.String
|
dev.DisplayName = displayName.String
|
||||||
}
|
}
|
||||||
|
if lastseents.Valid {
|
||||||
|
dev.LastSeenTS = lastseents.Int64
|
||||||
|
}
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
devices = append(devices, dev)
|
devices = append(devices, dev)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -28,9 +28,9 @@ import (
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
|
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
||||||
// and sets postgres connection parameters
|
// and sets postgres connection parameters
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
|
func NewUserAPIDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
|
||||||
switch {
|
switch {
|
||||||
case dbProperties.ConnectionString.IsSQLite():
|
case dbProperties.ConnectionString.IsSQLite():
|
||||||
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||||
|
|
|
||||||
539
userapi/storage/storage_test.go
Normal file
539
userapi/storage/storage_test.go
Normal file
|
|
@ -0,0 +1,539 @@
|
||||||
|
package storage_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
const loginTokenLifetime = time.Minute
|
||||||
|
|
||||||
|
var (
|
||||||
|
openIDLifetimeMS = time.Minute.Milliseconds()
|
||||||
|
ctx = context.Background()
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||||
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := storage.NewUserAPIDatabase(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewUserAPIDatabase returned %s", err)
|
||||||
|
}
|
||||||
|
return db, close
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests storing and getting account data
|
||||||
|
func Test_AccountData(t *testing.T) {
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
|
defer close()
|
||||||
|
alice := test.NewUser()
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
events := room.Events()
|
||||||
|
|
||||||
|
contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID()))
|
||||||
|
err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom)
|
||||||
|
assert.NoError(t, err, "unable to save account data")
|
||||||
|
|
||||||
|
contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID))
|
||||||
|
err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal)
|
||||||
|
assert.NoError(t, err, "unable to save account data")
|
||||||
|
|
||||||
|
accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read")
|
||||||
|
assert.NoError(t, err, "unable to get account data by type")
|
||||||
|
assert.Equal(t, contentRoom, accountData)
|
||||||
|
|
||||||
|
globalData, roomData, err := db.GetAccountData(ctx, localpart)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"])
|
||||||
|
assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests the creation of accounts
|
||||||
|
func Test_Accounts(t *testing.T) {
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
|
defer close()
|
||||||
|
alice := test.NewUser()
|
||||||
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
||||||
|
assert.NoError(t, err, "failed to create account")
|
||||||
|
// verify the newly create account is the same as returned by CreateAccount
|
||||||
|
var accGet *api.Account
|
||||||
|
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
|
||||||
|
assert.NoError(t, err, "failed to get account by password")
|
||||||
|
assert.Equal(t, accAlice, accGet)
|
||||||
|
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
|
||||||
|
assert.NoError(t, err, "failed to get account by localpart")
|
||||||
|
assert.Equal(t, accAlice, accGet)
|
||||||
|
|
||||||
|
// check account availability
|
||||||
|
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart)
|
||||||
|
assert.NoError(t, err, "failed to checkout account availability")
|
||||||
|
assert.Equal(t, false, available)
|
||||||
|
|
||||||
|
available, err = db.CheckAccountAvailability(ctx, "unusedname")
|
||||||
|
assert.NoError(t, err, "failed to checkout account availability")
|
||||||
|
assert.Equal(t, true, available)
|
||||||
|
|
||||||
|
// get guest account numeric aliceLocalpart
|
||||||
|
first, err := db.GetNewNumericLocalpart(ctx)
|
||||||
|
assert.NoError(t, err, "failed to get new numeric localpart")
|
||||||
|
// Create a new account to verify the numeric localpart is updated
|
||||||
|
_, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest)
|
||||||
|
assert.NoError(t, err, "failed to create account")
|
||||||
|
second, err := db.GetNewNumericLocalpart(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Greater(t, second, first)
|
||||||
|
|
||||||
|
// update password for alice
|
||||||
|
err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
|
||||||
|
assert.NoError(t, err, "failed to update password")
|
||||||
|
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
||||||
|
assert.NoError(t, err, "failed to get account by new password")
|
||||||
|
assert.Equal(t, accAlice, accGet)
|
||||||
|
|
||||||
|
// deactivate account
|
||||||
|
err = db.DeactivateAccount(ctx, aliceLocalpart)
|
||||||
|
assert.NoError(t, err, "failed to deactivate account")
|
||||||
|
// This should fail now, as the account is deactivated
|
||||||
|
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
||||||
|
assert.Error(t, err, "expected an error, got none")
|
||||||
|
|
||||||
|
_, err = db.GetAccountByLocalpart(ctx, "unusename")
|
||||||
|
assert.Error(t, err, "expected an error for non existent localpart")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Devices(t *testing.T) {
|
||||||
|
alice := test.NewUser()
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
deviceID := util.RandomString(8)
|
||||||
|
accessToken := util.RandomString(16)
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "")
|
||||||
|
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||||
|
|
||||||
|
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
|
||||||
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
|
assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
|
||||||
|
|
||||||
|
gotDeviceAccessToken, err := db.GetDeviceByAccessToken(ctx, accessToken)
|
||||||
|
assert.NoError(t, err, "unable to get device by access token")
|
||||||
|
assert.Equal(t, deviceWithID.ID, gotDeviceAccessToken.ID) // GetDeviceByAccessToken doesn't populate all fields
|
||||||
|
|
||||||
|
// create a device without existing device ID
|
||||||
|
accessToken = util.RandomString(16)
|
||||||
|
deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "")
|
||||||
|
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||||
|
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
|
||||||
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
|
assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
|
||||||
|
|
||||||
|
// Get devices
|
||||||
|
devices, err := db.GetDevicesByLocalpart(ctx, localpart)
|
||||||
|
assert.NoError(t, err, "unable to get devices by localpart")
|
||||||
|
assert.Equal(t, 2, len(devices))
|
||||||
|
deviceIDs := make([]string, 0, len(devices))
|
||||||
|
for _, dev := range devices {
|
||||||
|
deviceIDs = append(deviceIDs, dev.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
devices2, err := db.GetDevicesByID(ctx, deviceIDs)
|
||||||
|
assert.NoError(t, err, "unable to get devices by id")
|
||||||
|
assert.Equal(t, devices, devices2)
|
||||||
|
|
||||||
|
// Update device
|
||||||
|
newName := "new display name"
|
||||||
|
err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName)
|
||||||
|
assert.NoError(t, err, "unable to update device displayname")
|
||||||
|
err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1")
|
||||||
|
assert.NoError(t, err, "unable to update device last seen")
|
||||||
|
|
||||||
|
deviceWithID.DisplayName = newName
|
||||||
|
deviceWithID.LastSeenIP = "127.0.0.1"
|
||||||
|
deviceWithID.LastSeenTS = int64(gomatrixserverlib.AsTimestamp(time.Now().Truncate(time.Second)))
|
||||||
|
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
||||||
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
|
assert.Equal(t, 2, len(devices))
|
||||||
|
assert.Equal(t, deviceWithID.DisplayName, devices[0].DisplayName)
|
||||||
|
assert.Equal(t, deviceWithID.LastSeenIP, devices[0].LastSeenIP)
|
||||||
|
truncatedTime := gomatrixserverlib.Timestamp(devices[0].LastSeenTS).Time().Truncate(time.Second)
|
||||||
|
assert.Equal(t, gomatrixserverlib.Timestamp(deviceWithID.LastSeenTS), gomatrixserverlib.AsTimestamp(truncatedTime))
|
||||||
|
|
||||||
|
// create one more device and remove the devices step by step
|
||||||
|
newDeviceID := util.RandomString(16)
|
||||||
|
accessToken = util.RandomString(16)
|
||||||
|
_, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "")
|
||||||
|
assert.NoError(t, err, "unable to create new device")
|
||||||
|
|
||||||
|
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
||||||
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
|
assert.Equal(t, 3, len(devices))
|
||||||
|
|
||||||
|
err = db.RemoveDevices(ctx, localpart, deviceIDs)
|
||||||
|
assert.NoError(t, err, "unable to remove devices")
|
||||||
|
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
||||||
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
|
assert.Equal(t, 1, len(devices))
|
||||||
|
|
||||||
|
deleted, err := db.RemoveAllDevices(ctx, localpart, "")
|
||||||
|
assert.NoError(t, err, "unable to remove all devices")
|
||||||
|
assert.Equal(t, 1, len(deleted))
|
||||||
|
assert.Equal(t, newDeviceID, deleted[0].ID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_KeyBackup(t *testing.T) {
|
||||||
|
alice := test.NewUser()
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
wantAuthData := json.RawMessage("my auth data")
|
||||||
|
wantVersion, err := db.CreateKeyBackup(ctx, alice.ID, "dummyAlgo", wantAuthData)
|
||||||
|
assert.NoError(t, err, "unable to create key backup")
|
||||||
|
// get key backup by version
|
||||||
|
gotVersion, gotAlgo, gotAuthData, _, _, err := db.GetKeyBackup(ctx, alice.ID, wantVersion)
|
||||||
|
assert.NoError(t, err, "unable to get key backup")
|
||||||
|
assert.Equal(t, wantVersion, gotVersion, "backup version mismatch")
|
||||||
|
assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch")
|
||||||
|
assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch")
|
||||||
|
|
||||||
|
// get any key backup
|
||||||
|
gotVersion, gotAlgo, gotAuthData, _, _, err = db.GetKeyBackup(ctx, alice.ID, "")
|
||||||
|
assert.NoError(t, err, "unable to get key backup")
|
||||||
|
assert.Equal(t, wantVersion, gotVersion, "backup version mismatch")
|
||||||
|
assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch")
|
||||||
|
assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch")
|
||||||
|
|
||||||
|
err = db.UpdateKeyBackupAuthData(ctx, alice.ID, wantVersion, json.RawMessage("my updated auth data"))
|
||||||
|
assert.NoError(t, err, "unable to update key backup auth data")
|
||||||
|
|
||||||
|
uploads := []api.InternalKeyBackupSession{
|
||||||
|
{
|
||||||
|
KeyBackupSession: api.KeyBackupSession{
|
||||||
|
IsVerified: true,
|
||||||
|
SessionData: wantAuthData,
|
||||||
|
},
|
||||||
|
RoomID: room.ID,
|
||||||
|
SessionID: "1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
KeyBackupSession: api.KeyBackupSession{},
|
||||||
|
RoomID: room.ID,
|
||||||
|
SessionID: "2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
count, _, err := db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads)
|
||||||
|
assert.NoError(t, err, "unable to upsert backup keys")
|
||||||
|
assert.Equal(t, int64(len(uploads)), count, "unexpected backup count")
|
||||||
|
|
||||||
|
// do it again to update a key
|
||||||
|
uploads[1].IsVerified = true
|
||||||
|
count, _, err = db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads[1:])
|
||||||
|
assert.NoError(t, err, "unable to upsert backup keys")
|
||||||
|
assert.Equal(t, int64(len(uploads)), count, "unexpected backup count")
|
||||||
|
|
||||||
|
// get backup keys by session id
|
||||||
|
gotBackupKeys, err := db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "1")
|
||||||
|
assert.NoError(t, err, "unable to get backup keys")
|
||||||
|
assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"])
|
||||||
|
|
||||||
|
// get backup keys by room id
|
||||||
|
gotBackupKeys, err = db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "")
|
||||||
|
assert.NoError(t, err, "unable to get backup keys")
|
||||||
|
assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"])
|
||||||
|
|
||||||
|
gotCount, err := db.CountBackupKeys(ctx, wantVersion, alice.ID)
|
||||||
|
assert.NoError(t, err, "unable to get backup keys count")
|
||||||
|
assert.Equal(t, count, gotCount, "unexpected backup count")
|
||||||
|
|
||||||
|
// finally delete a key
|
||||||
|
exists, err := db.DeleteKeyBackup(ctx, alice.ID, wantVersion)
|
||||||
|
assert.NoError(t, err, "unable to delete key backup")
|
||||||
|
assert.True(t, exists)
|
||||||
|
|
||||||
|
// this key should not exist
|
||||||
|
exists, err = db.DeleteKeyBackup(ctx, alice.ID, "3")
|
||||||
|
assert.NoError(t, err, "unable to delete key backup")
|
||||||
|
assert.False(t, exists)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_LoginToken(t *testing.T) {
|
||||||
|
alice := test.NewUser()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
// create a new token
|
||||||
|
wantLoginToken := &api.LoginTokenData{UserID: alice.ID}
|
||||||
|
|
||||||
|
gotMetadata, err := db.CreateLoginToken(ctx, wantLoginToken)
|
||||||
|
assert.NoError(t, err, "unable to create login token")
|
||||||
|
assert.NotNil(t, gotMetadata)
|
||||||
|
assert.Equal(t, time.Now().Add(loginTokenLifetime).Truncate(loginTokenLifetime), gotMetadata.Expiration.Truncate(loginTokenLifetime))
|
||||||
|
|
||||||
|
// get the new token
|
||||||
|
gotLoginToken, err := db.GetLoginTokenDataByToken(ctx, gotMetadata.Token)
|
||||||
|
assert.NoError(t, err, "unable to get login token")
|
||||||
|
assert.NotNil(t, gotLoginToken)
|
||||||
|
assert.Equal(t, wantLoginToken, gotLoginToken, "unexpected login token")
|
||||||
|
|
||||||
|
// remove the login token again
|
||||||
|
err = db.RemoveLoginToken(ctx, gotMetadata.Token)
|
||||||
|
assert.NoError(t, err, "unable to remove login token")
|
||||||
|
|
||||||
|
// check if the token was actually deleted
|
||||||
|
_, err = db.GetLoginTokenDataByToken(ctx, gotMetadata.Token)
|
||||||
|
assert.Error(t, err, "expected an error, but got none")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_OpenID(t *testing.T) {
|
||||||
|
alice := test.NewUser()
|
||||||
|
token := util.RandomString(24)
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS
|
||||||
|
expires, err := db.CreateOpenIDToken(ctx, token, alice.ID)
|
||||||
|
assert.NoError(t, err, "unable to create OpenID token")
|
||||||
|
assert.Equal(t, expiresAtMS, expires)
|
||||||
|
|
||||||
|
attributes, err := db.GetOpenIDTokenAttributes(ctx, token)
|
||||||
|
assert.NoError(t, err, "unable to get OpenID token attributes")
|
||||||
|
assert.Equal(t, alice.ID, attributes.UserID)
|
||||||
|
assert.Equal(t, expiresAtMS, attributes.ExpiresAtMS)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Profile(t *testing.T) {
|
||||||
|
alice := test.NewUser()
|
||||||
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
// create account, which also creates a profile
|
||||||
|
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
||||||
|
assert.NoError(t, err, "failed to create account")
|
||||||
|
|
||||||
|
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)
|
||||||
|
assert.NoError(t, err, "unable to get profile by localpart")
|
||||||
|
wantProfile := &authtypes.Profile{Localpart: aliceLocalpart}
|
||||||
|
assert.Equal(t, wantProfile, gotProfile)
|
||||||
|
|
||||||
|
// set avatar & displayname
|
||||||
|
wantProfile.DisplayName = "Alice"
|
||||||
|
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
||||||
|
err = db.SetDisplayName(ctx, aliceLocalpart, "Alice")
|
||||||
|
assert.NoError(t, err, "unable to set displayname")
|
||||||
|
err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
|
||||||
|
assert.NoError(t, err, "unable to set avatar url")
|
||||||
|
// verify profile
|
||||||
|
gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart)
|
||||||
|
assert.NoError(t, err, "unable to get profile by localpart")
|
||||||
|
assert.Equal(t, wantProfile, gotProfile)
|
||||||
|
|
||||||
|
// search profiles
|
||||||
|
searchRes, err := db.SearchProfiles(ctx, "Alice", 2)
|
||||||
|
assert.NoError(t, err, "unable to search profiles")
|
||||||
|
assert.Equal(t, 1, len(searchRes))
|
||||||
|
assert.Equal(t, *wantProfile, searchRes[0])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Pusher(t *testing.T) {
|
||||||
|
alice := test.NewUser()
|
||||||
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
appID := util.RandomString(8)
|
||||||
|
var pushKeys []string
|
||||||
|
var gotPushers []api.Pusher
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
pushKey := util.RandomString(8)
|
||||||
|
|
||||||
|
wantPusher := api.Pusher{
|
||||||
|
PushKey: pushKey,
|
||||||
|
Kind: api.HTTPKind,
|
||||||
|
AppID: appID,
|
||||||
|
AppDisplayName: util.RandomString(8),
|
||||||
|
DeviceDisplayName: util.RandomString(8),
|
||||||
|
ProfileTag: util.RandomString(8),
|
||||||
|
Language: util.RandomString(2),
|
||||||
|
}
|
||||||
|
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart)
|
||||||
|
assert.NoError(t, err, "unable to upsert pusher")
|
||||||
|
|
||||||
|
// check it was actually persisted
|
||||||
|
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
|
||||||
|
assert.NoError(t, err, "unable to get pushers")
|
||||||
|
assert.Equal(t, i+1, len(gotPushers))
|
||||||
|
assert.Equal(t, wantPusher, gotPushers[i])
|
||||||
|
pushKeys = append(pushKeys, pushKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove single pusher
|
||||||
|
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart)
|
||||||
|
assert.NoError(t, err, "unable to remove pusher")
|
||||||
|
gotPushers, err := db.GetPushers(ctx, aliceLocalpart)
|
||||||
|
assert.NoError(t, err, "unable to get pushers")
|
||||||
|
assert.Equal(t, 1, len(gotPushers))
|
||||||
|
|
||||||
|
// remove last pusher
|
||||||
|
err = db.RemovePushers(ctx, appID, pushKeys[1])
|
||||||
|
assert.NoError(t, err, "unable to remove pusher")
|
||||||
|
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
|
||||||
|
assert.NoError(t, err, "unable to get pushers")
|
||||||
|
assert.Equal(t, 0, len(gotPushers))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_ThreePID(t *testing.T) {
|
||||||
|
alice := test.NewUser()
|
||||||
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
|
defer close()
|
||||||
|
threePID := util.RandomString(8)
|
||||||
|
medium := util.RandomString(8)
|
||||||
|
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium)
|
||||||
|
assert.NoError(t, err, "unable to save threepid association")
|
||||||
|
|
||||||
|
// get the stored threepid
|
||||||
|
gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
|
||||||
|
assert.NoError(t, err, "unable to get localpart for threepid")
|
||||||
|
assert.Equal(t, aliceLocalpart, gotLocalpart)
|
||||||
|
|
||||||
|
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
|
||||||
|
assert.NoError(t, err, "unable to get threepids for localpart")
|
||||||
|
assert.Equal(t, 1, len(threepids))
|
||||||
|
assert.Equal(t, authtypes.ThreePID{
|
||||||
|
Address: threePID,
|
||||||
|
Medium: medium,
|
||||||
|
}, threepids[0])
|
||||||
|
|
||||||
|
// remove threepid association
|
||||||
|
err = db.RemoveThreePIDAssociation(ctx, threePID, medium)
|
||||||
|
assert.NoError(t, err, "unexpected error")
|
||||||
|
|
||||||
|
// verify it was deleted
|
||||||
|
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
|
||||||
|
assert.NoError(t, err, "unable to get threepids for localpart")
|
||||||
|
assert.Equal(t, 0, len(threepids))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Notification(t *testing.T) {
|
||||||
|
alice := test.NewUser()
|
||||||
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
room2 := test.NewRoom(t, alice)
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
|
defer close()
|
||||||
|
// generate some dummy notifications
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
eventID := util.RandomString(16)
|
||||||
|
roomID := room.ID
|
||||||
|
ts := time.Now()
|
||||||
|
if i > 5 {
|
||||||
|
roomID = room2.ID
|
||||||
|
// create some old notifications to test DeleteOldNotifications
|
||||||
|
ts = ts.AddDate(0, -2, 0)
|
||||||
|
}
|
||||||
|
notification := &api.Notification{
|
||||||
|
Actions: []*pushrules.Action{
|
||||||
|
{},
|
||||||
|
},
|
||||||
|
Event: gomatrixserverlib.ClientEvent{
|
||||||
|
Content: gomatrixserverlib.RawJSON("{}"),
|
||||||
|
},
|
||||||
|
Read: false,
|
||||||
|
RoomID: roomID,
|
||||||
|
TS: gomatrixserverlib.AsTimestamp(ts),
|
||||||
|
}
|
||||||
|
err = db.InsertNotification(ctx, aliceLocalpart, eventID, int64(i+1), nil, notification)
|
||||||
|
assert.NoError(t, err, "unable to insert notification")
|
||||||
|
}
|
||||||
|
|
||||||
|
// get notifications
|
||||||
|
count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications)
|
||||||
|
assert.NoError(t, err, "unable to get notification count")
|
||||||
|
assert.Equal(t, int64(10), count)
|
||||||
|
notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications)
|
||||||
|
assert.NoError(t, err, "unable to get notifications")
|
||||||
|
assert.Equal(t, int64(10), count)
|
||||||
|
assert.Equal(t, 10, len(notifs))
|
||||||
|
// ... for a specific room
|
||||||
|
total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
||||||
|
assert.NoError(t, err, "unable to get notifications for room")
|
||||||
|
assert.Equal(t, int64(4), total)
|
||||||
|
|
||||||
|
// mark notification as read
|
||||||
|
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true)
|
||||||
|
assert.NoError(t, err, "unable to set notifications read")
|
||||||
|
assert.True(t, affected)
|
||||||
|
|
||||||
|
// this should delete 2 notifications
|
||||||
|
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8)
|
||||||
|
assert.NoError(t, err, "unable to set notifications read")
|
||||||
|
assert.True(t, affected)
|
||||||
|
|
||||||
|
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
||||||
|
assert.NoError(t, err, "unable to get notifications for room")
|
||||||
|
assert.Equal(t, int64(2), total)
|
||||||
|
|
||||||
|
// delete old notifications
|
||||||
|
err = db.DeleteOldNotifications(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// this should now return 0 notifications
|
||||||
|
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
||||||
|
assert.NoError(t, err, "unable to get notifications for room")
|
||||||
|
assert.Equal(t, int64(0), total)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -23,7 +23,7 @@ import (
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewDatabase(
|
func NewUserAPIDatabase(
|
||||||
dbProperties *config.DatabaseOptions,
|
dbProperties *config.DatabaseOptions,
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
bcryptCost int,
|
bcryptCost int,
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s
|
||||||
MaxOpenConnections: 1,
|
MaxOpenConnections: 1,
|
||||||
MaxIdleConnections: 1,
|
MaxIdleConnections: 1,
|
||||||
}
|
}
|
||||||
accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
|
accountDB, err := storage.NewUserAPIDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create account DB: %s", err)
|
t.Fatalf("failed to create account DB: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue