Merge branch 'matrix-org:main' into main

This commit is contained in:
Emanuele Aliberti 2022-04-20 11:16:38 +02:00 committed by GitHub
commit d0666902f9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
87 changed files with 2310 additions and 1496 deletions

View file

@ -73,6 +73,26 @@ jobs:
timeout-minutes: 5
name: Unit tests (Go ${{ matrix.go }})
runs-on: ubuntu-latest
# Service containers to run with `container-job`
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres:13-alpine
# Provide the password for postgres
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: dendrite
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
strategy:
fail-fast: false
matrix:
@ -92,6 +112,11 @@ jobs:
restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-test-
- run: go test ./...
env:
POSTGRES_HOST: localhost
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: dendrite
# build Dendrite for linux with different architectures and go versions
build:

View file

@ -52,6 +52,7 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeMulticast "github.com/matrix-org/pinecone/multicast"
pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeSessions "github.com/matrix-org/pinecone/sessions"
@ -71,11 +72,9 @@ type DendriteMonolith struct {
PineconeRouter *pineconeRouter.Router
PineconeMulticast *pineconeMulticast.Multicast
PineconeQUIC *pineconeSessions.Sessions
PineconeManager *pineconeConnections.ConnectionManager
StorageDirectory string
CacheDirectory string
staticPeerURI string
staticPeerMutex sync.RWMutex
staticPeerAttempt chan struct{}
listener net.Listener
httpServer *http.Server
processContext *process.ProcessContext
@ -104,15 +103,8 @@ func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) {
}
func (m *DendriteMonolith) SetStaticPeer(uri string) {
m.staticPeerMutex.Lock()
m.staticPeerURI = strings.TrimSpace(uri)
m.staticPeerMutex.Unlock()
m.DisconnectType(int(pineconeRouter.PeerTypeRemote))
if uri != "" {
go func() {
m.staticPeerAttempt <- struct{}{}
}()
}
m.PineconeManager.RemovePeers()
m.PineconeManager.AddPeer(strings.TrimSpace(uri))
}
func (m *DendriteMonolith) DisconnectType(peertype int) {
@ -210,43 +202,6 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e
return loginRes.Device.AccessToken, nil
}
func (m *DendriteMonolith) staticPeerConnect() {
connected := map[string]bool{} // URI -> connected?
attempt := func() {
m.staticPeerMutex.RLock()
uri := m.staticPeerURI
m.staticPeerMutex.RUnlock()
if uri == "" {
return
}
for k := range connected {
delete(connected, k)
}
for _, uri := range strings.Split(uri, ",") {
connected[strings.TrimSpace(uri)] = false
}
for _, info := range m.PineconeRouter.Peers() {
connected[info.URI] = true
}
for k, online := range connected {
if !online {
if err := conn.ConnectToPeer(m.PineconeRouter, k); err != nil {
logrus.WithError(err).Error("Failed to connect to static peer")
}
}
}
}
for {
select {
case <-m.processContext.Context().Done():
case <-m.staticPeerAttempt:
attempt()
case <-time.After(time.Second * 5):
attempt()
}
}
}
// nolint:gocyclo
func (m *DendriteMonolith) Start() {
var err error
@ -284,6 +239,7 @@ func (m *DendriteMonolith) Start() {
m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"})
m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter)
m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter)
prefix := hex.EncodeToString(pk)
cfg := &config.Dendrite{}
@ -392,9 +348,6 @@ func (m *DendriteMonolith) Start() {
m.processContext = base.ProcessContext
m.staticPeerAttempt = make(chan struct{}, 1)
go m.staticPeerConnect()
go func() {
m.logger.Info("Listening on ", cfg.Global.ServerName)
m.logger.Fatal(m.httpServer.Serve(m.PineconeQUIC.Protocol("matrix")))

View file

@ -95,10 +95,10 @@ func SaveAccountData(
}
}
if dataType == "m.fully_read" {
if dataType == "m.fully_read" || dataType == "m.push_rules" {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Unable to set read marker"),
JSON: jsonerror.Forbidden(fmt.Sprintf("Unable to modify %q using this API", dataType)),
}
}

View file

@ -64,11 +64,6 @@ const (
sessionIDLength = 24
)
func init() {
// Register prometheus metrics. They must be registered to be exposed.
prometheus.MustRegister(amtRegUsers)
}
// sessionsDict keeps track of completed auth stages for each session.
// It shouldn't be passed by value because it contains a mutex.
type sessionsDict struct {

View file

@ -37,6 +37,7 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/nats-io/nats.go"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
)
@ -60,6 +61,8 @@ func Setup(
extRoomsProvider api.ExtraPublicRoomsProvider,
mscCfg *config.MSCs, natsClient *nats.Conn,
) {
prometheus.MustRegister(amtRegUsers, sendEventDuration)
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg)

View file

@ -46,10 +46,6 @@ var (
userRoomSendMutexes sync.Map // (roomID+userID) -> mutex. mutexes to ensure correct ordering of sendEvents
)
func init() {
prometheus.MustRegister(sendEventDuration)
}
var sendEventDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",

View file

@ -52,6 +52,7 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client
if turnConfig.SharedSecret != "" {
expiry := time.Now().Add(duration).Unix()
resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID)
mac := hmac.New(sha1.New, []byte(turnConfig.SharedSecret))
_, err := mac.Write([]byte(resp.Username))
@ -60,7 +61,6 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client
return jsonerror.InternalServerError()
}
resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID)
resp.Password = base64.StdEncoding.EncodeToString(mac.Sum(nil))
} else if turnConfig.Username != "" && turnConfig.Password != "" {
resp.Username = turnConfig.Username

View file

@ -1,156 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"flag"
"fmt"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
const usage = `Usage: %s
Create a single endpoint URL which clients can be pointed at.
The client-server API in Dendrite is split across multiple processes
which listen on multiple ports. You cannot point a Matrix client at
any of those ports, as there will be unimplemented functionality.
In addition, all client-server API processes start with the additional
path prefix '/api', which Matrix clients will be unaware of.
This tool will proxy requests for all client-server URLs and forward
them to their respective process. It will also add the '/api' path
prefix to incoming requests.
THIS TOOL IS FOR TESTING AND NOT INTENDED FOR PRODUCTION USE.
Arguments:
`
var (
syncServerURL = flag.String("sync-api-server-url", "", "The base URL of the listening 'dendrite-sync-api-server' process. E.g. 'http://localhost:4200'")
clientAPIURL = flag.String("client-api-server-url", "", "The base URL of the listening 'dendrite-client-api-server' process. E.g. 'http://localhost:4321'")
mediaAPIURL = flag.String("media-api-server-url", "", "The base URL of the listening 'dendrite-media-api-server' process. E.g. 'http://localhost:7779'")
bindAddress = flag.String("bind-address", ":8008", "The listening port for the proxy.")
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS")
keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS")
)
func makeProxy(targetURL string) (*httputil.ReverseProxy, error) {
targetURL = strings.TrimSuffix(targetURL, "/")
// Check that we can parse the URL.
_, err := url.Parse(targetURL)
if err != nil {
return nil, err
}
return &httputil.ReverseProxy{
Director: func(req *http.Request) {
// URL.Path() removes the % escaping from the path.
// The % encoding will be added back when the url is encoded
// when the request is forwarded.
// This means that we will lose any unessecary escaping from the URL.
// Pratically this means that any distinction between '%2F' and '/'
// in the URL will be lost by the time it reaches the target.
path := req.URL.Path
log.WithFields(log.Fields{
"path": path,
"url": targetURL,
"method": req.Method,
}).Print("proxying request")
newURL, err := url.Parse(targetURL)
// Set the path separately as we need to preserve '#' characters
// that would otherwise be interpreted as being the start of a URL
// fragment.
newURL.Path += path
if err != nil {
// We already checked that we can parse the URL
// So this shouldn't ever get hit.
panic(err)
}
// Copy the query parameters from the request.
newURL.RawQuery = req.URL.RawQuery
req.URL = newURL
},
}, nil
}
func main() {
flag.Usage = func() {
fmt.Fprintf(os.Stderr, usage, os.Args[0])
flag.PrintDefaults()
}
flag.Parse()
if *syncServerURL == "" {
flag.Usage()
fmt.Fprintln(os.Stderr, "no --sync-api-server-url specified.")
os.Exit(1)
}
if *clientAPIURL == "" {
flag.Usage()
fmt.Fprintln(os.Stderr, "no --client-api-server-url specified.")
os.Exit(1)
}
if *mediaAPIURL == "" {
flag.Usage()
fmt.Fprintln(os.Stderr, "no --media-api-server-url specified.")
os.Exit(1)
}
syncProxy, err := makeProxy(*syncServerURL)
if err != nil {
panic(err)
}
clientProxy, err := makeProxy(*clientAPIURL)
if err != nil {
panic(err)
}
mediaProxy, err := makeProxy(*mediaAPIURL)
if err != nil {
panic(err)
}
http.Handle("/_matrix/client/r0/sync", syncProxy)
http.Handle("/_matrix/media/v1/", mediaProxy)
http.Handle("/", clientProxy)
srv := &http.Server{
Addr: *bindAddress,
ReadTimeout: 1 * time.Minute, // how long we wait for the client to send the entire request (after connection accept)
WriteTimeout: 5 * time.Minute, // how long the proxy has to write the full response
}
fmt.Println("Proxying requests to:")
fmt.Println(" /_matrix/client/r0/sync => ", *syncServerURL+"/api/_matrix/client/r0/sync")
fmt.Println(" /_matrix/media/v1 => ", *mediaAPIURL+"/api/_matrix/media/v1")
fmt.Println(" /* => ", *clientAPIURL+"/api/*")
fmt.Println("Listening on ", *bindAddress)
if *certFile != "" && *keyFile != "" {
panic(srv.ListenAndServeTLS(*certFile, *keyFile))
} else {
panic(srv.ListenAndServe())
}
}

View file

@ -25,7 +25,6 @@ import (
"net"
"net/http"
"os"
"strings"
"time"
"github.com/gorilla/mux"
@ -47,6 +46,7 @@ import (
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/gomatrixserverlib"
pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeMulticast "github.com/matrix-org/pinecone/multicast"
pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeSessions "github.com/matrix-org/pinecone/sessions"
@ -90,6 +90,13 @@ func main() {
}
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
pManager := pineconeConnections.NewConnectionManager(pRouter)
pMulticast.Start()
if instancePeer != nil && *instancePeer != "" {
pManager.AddPeer(*instancePeer)
}
go func() {
listener, err := net.Listen("tcp", *instanceListen)
@ -119,36 +126,6 @@ func main() {
}
}()
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
pMulticast.Start()
connectToStaticPeer := func() {
connected := map[string]bool{} // URI -> connected?
for _, uri := range strings.Split(*instancePeer, ",") {
connected[strings.TrimSpace(uri)] = false
}
attempt := func() {
for k := range connected {
connected[k] = false
}
for _, info := range pRouter.Peers() {
connected[info.URI] = true
}
for k, online := range connected {
if !online {
if err := conn.ConnectToPeer(pRouter, k); err != nil {
logrus.WithError(err).Error("Failed to connect to static peer")
}
}
}
}
for {
attempt()
time.Sleep(time.Second * 5)
}
}
cfg := &config.Dendrite{}
cfg.Defaults(true)
cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))
@ -268,7 +245,6 @@ func main() {
Handler: pMux,
}
go connectToStaticPeer()
go func() {
pubkey := pRouter.PublicKey()
logrus.Info("Listening on ", hex.EncodeToString(pubkey[:]))

View file

@ -22,7 +22,6 @@ import (
"encoding/hex"
"fmt"
"syscall/js"
"time"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/appservice"
@ -44,6 +43,7 @@ import (
_ "github.com/matrix-org/go-sqlite3-js"
pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeSessions "github.com/matrix-org/pinecone/sessions"
)
@ -154,6 +154,8 @@ func startup() {
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
pSessions := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
pManager := pineconeConnections.NewConnectionManager(pRouter)
pManager.AddPeer("wss://pinecone.matrix.org/public")
cfg := &config.Dendrite{}
cfg.Defaults(true)
@ -237,20 +239,4 @@ func startup() {
}
s.ListenAndServe("fetch")
}()
// Connect to the static peer
go func() {
for {
if pRouter.PeerCount(pineconeRouter.PeerTypeRemote) == 0 {
if err := conn.ConnectToPeer(pRouter, publicPeer); err != nil {
logrus.WithError(err).Error("Failed to connect to static peer")
}
}
select {
case <-base.ProcessContext.Context().Done():
return
case <-time.After(time.Second * 5):
}
}
}()
}

View file

@ -1,138 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"flag"
"fmt"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
const usage = `Usage: %s
Create a single endpoint URL which remote matrix servers can be pointed at.
The server-server API in Dendrite is split across multiple processes
which listen on multiple ports. You cannot point a Matrix server at
any of those ports, as there will be unimplemented functionality.
In addition, all server-server API processes start with the additional
path prefix '/api', which Matrix servers will be unaware of.
This tool will proxy requests for all server-server URLs and forward
them to their respective process. It will also add the '/api' path
prefix to incoming requests.
THIS TOOL IS FOR TESTING AND NOT INTENDED FOR PRODUCTION USE.
Arguments:
`
var (
federationAPIURL = flag.String("federation-api-url", "", "The base URL of the listening 'dendrite-federation-api-server' process. E.g. 'http://localhost:4200'")
mediaAPIURL = flag.String("media-api-server-url", "", "The base URL of the listening 'dendrite-media-api-server' process. E.g. 'http://localhost:7779'")
bindAddress = flag.String("bind-address", ":8448", "The listening port for the proxy.")
certFile = flag.String("tls-cert", "server.crt", "The PEM formatted X509 certificate to use for TLS")
keyFile = flag.String("tls-key", "server.key", "The PEM private key to use for TLS")
)
func makeProxy(targetURL string) (*httputil.ReverseProxy, error) {
if !strings.HasSuffix(targetURL, "/") {
targetURL += "/"
}
// Check that we can parse the URL.
_, err := url.Parse(targetURL)
if err != nil {
return nil, err
}
return &httputil.ReverseProxy{
Director: func(req *http.Request) {
// URL.Path() removes the % escaping from the path.
// The % encoding will be added back when the url is encoded
// when the request is forwarded.
// This means that we will lose any unessecary escaping from the URL.
// Pratically this means that any distinction between '%2F' and '/'
// in the URL will be lost by the time it reaches the target.
path := req.URL.Path
log.WithFields(log.Fields{
"path": path,
"url": targetURL,
"method": req.Method,
}).Print("proxying request")
newURL, err := url.Parse(targetURL + path)
if err != nil {
// We already checked that we can parse the URL
// So this shouldn't ever get hit.
panic(err)
}
// Copy the query parameters from the request.
newURL.RawQuery = req.URL.RawQuery
req.URL = newURL
},
}, nil
}
func main() {
flag.Usage = func() {
fmt.Fprintf(os.Stderr, usage, os.Args[0])
flag.PrintDefaults()
}
flag.Parse()
if *federationAPIURL == "" {
flag.Usage()
fmt.Fprintln(os.Stderr, "no --federation-api-url specified.")
os.Exit(1)
}
if *mediaAPIURL == "" {
flag.Usage()
fmt.Fprintln(os.Stderr, "no --media-api-server-url specified.")
os.Exit(1)
}
federationProxy, err := makeProxy(*federationAPIURL)
if err != nil {
panic(err)
}
mediaProxy, err := makeProxy(*mediaAPIURL)
if err != nil {
panic(err)
}
http.Handle("/_matrix/media/v1/", mediaProxy)
http.Handle("/", federationProxy)
srv := &http.Server{
Addr: *bindAddress,
ReadTimeout: 1 * time.Minute, // how long we wait for the client to send the entire request (after connection accept)
WriteTimeout: 5 * time.Minute, // how long the proxy has to write the full response
}
fmt.Println("Proxying requests to:")
fmt.Println(" /_matrix/media/v1 => ", *mediaAPIURL+"/api/_matrix/media/v1")
fmt.Println(" /* => ", *federationAPIURL+"/api/*")
fmt.Println("Listening on ", *bindAddress)
panic(srv.ListenAndServeTLS(*certFile, *keyFile))
}

View file

@ -29,6 +29,7 @@ import (
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
)
@ -53,6 +54,10 @@ func Setup(
servers federationAPI.ServersInRoomProvider,
producer *producers.SyncAPIProducer,
) {
prometheus.MustRegister(
pduCountTotal, eduCountTotal,
)
v2keysmux := keyMux.PathPrefix("/v2").Subrouter()
v1fedmux := fedMux.PathPrefix("/v1").Subrouter()
v2fedmux := fedMux.PathPrefix("/v2").Subrouter()

View file

@ -74,12 +74,6 @@ var (
)
)
func init() {
prometheus.MustRegister(
pduCountTotal, eduCountTotal,
)
}
var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse
// Send implements /_matrix/federation/v1/send/{txnID}

38
go.mod
View file

@ -12,20 +12,20 @@ require (
github.com/MFAshby/stdemuxerhook v1.0.0
github.com/Masterminds/semver/v3 v3.1.1
github.com/codeclysm/extract v2.2.0+incompatible
github.com/containerd/containerd v1.5.9 // indirect
github.com/docker/docker v20.10.12+incompatible
github.com/containerd/containerd v1.6.2 // indirect
github.com/docker/docker v20.10.14+incompatible
github.com/docker/go-connections v0.4.0
github.com/frankban/quicktest v1.14.0 // indirect
github.com/getsentry/sentry-go v0.12.0
github.com/frankban/quicktest v1.14.3 // indirect
github.com/getsentry/sentry-go v0.13.0
github.com/gologme/log v1.3.0
github.com/google/go-cmp v0.5.6
github.com/google/uuid v1.2.0
github.com/google/go-cmp v0.5.7
github.com/google/uuid v1.3.0
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.4.2
github.com/gorilla/websocket v1.5.0
github.com/h2non/filetype v1.1.3 // indirect
github.com/hashicorp/golang-lru v0.5.4
github.com/juju/testing v0.0.0-20211215003918-77eb13d6cad2 // indirect
github.com/lib/pq v1.10.4
github.com/juju/testing v0.0.0-20220203020004-a0ff61f03494 // indirect
github.com/lib/pq v1.10.5
github.com/libp2p/go-libp2p v0.13.0
github.com/libp2p/go-libp2p-circuit v0.4.0
github.com/libp2p/go-libp2p-core v0.8.3
@ -36,18 +36,18 @@ require (
github.com/libp2p/go-libp2p-record v0.1.3
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5
github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d
github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f
github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.10
github.com/morikuni/aec v1.0.0 // indirect
github.com/nats-io/nats-server/v2 v2.7.4-0.20220309205833-773636c1c5bb
github.com/nats-io/nats.go v1.13.1-0.20220308171302-2f2f6968e98d
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31
github.com/opencontainers/image-spec v1.0.2 // indirect
github.com/opentracing/opentracing-go v1.2.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1
@ -60,14 +60,14 @@ require (
github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.3
go.uber.org/atomic v1.9.0
golang.org/x/crypto v0.0.0-20220214200702-86341886e292
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
golang.org/x/mobile v0.0.0-20220325161704-447654d348e3
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29
golang.org/x/image v0.0.0-20220321031419-a8550c1d254a
golang.org/x/mobile v0.0.0-20220407111146-e579adbbc4a2
golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
gopkg.in/h2non/bimg.v1 v1.1.5
gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
nhooyr.io/websocket v1.8.7
)

445
go.sum

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,86 @@
package caching
import (
"fmt"
"time"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
const (
LazyLoadCacheName = "lazy_load_members"
LazyLoadCacheMaxEntries = 128
LazyLoadCacheMaxUserEntries = 128
LazyLoadCacheMutable = true
LazyLoadCacheMaxAge = time.Minute * 30
)
type LazyLoadCache struct {
// InMemoryLRUCachePartition containing other InMemoryLRUCachePartitions
// with the actual cached members
userCaches *InMemoryLRUCachePartition
}
// NewLazyLoadCache creates a new LazyLoadCache.
func NewLazyLoadCache() (*LazyLoadCache, error) {
cache, err := NewInMemoryLRUCachePartition(
LazyLoadCacheName,
LazyLoadCacheMutable,
LazyLoadCacheMaxEntries,
LazyLoadCacheMaxAge,
true,
)
if err != nil {
return nil, err
}
go cacheCleaner(cache)
return &LazyLoadCache{
userCaches: cache,
}, nil
}
func (c *LazyLoadCache) lazyLoadCacheForUser(device *userapi.Device) (*InMemoryLRUCachePartition, error) {
cacheName := fmt.Sprintf("%s/%s", device.UserID, device.ID)
userCache, ok := c.userCaches.Get(cacheName)
if ok && userCache != nil {
if cache, ok := userCache.(*InMemoryLRUCachePartition); ok {
return cache, nil
}
}
cache, err := NewInMemoryLRUCachePartition(
LazyLoadCacheName,
LazyLoadCacheMutable,
LazyLoadCacheMaxUserEntries,
LazyLoadCacheMaxAge,
false,
)
if err != nil {
return nil, err
}
c.userCaches.Set(cacheName, cache)
go cacheCleaner(cache)
return cache, nil
}
func (c *LazyLoadCache) StoreLazyLoadedUser(device *userapi.Device, roomID, userID, eventID string) {
cache, err := c.lazyLoadCacheForUser(device)
if err != nil {
return
}
cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID)
cache.Set(cacheKey, eventID)
}
func (c *LazyLoadCache) IsLazyLoadedUserCached(device *userapi.Device, roomID, userID string) (string, bool) {
cache, err := c.lazyLoadCacheForUser(device)
if err != nil {
return "", false
}
cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID)
val, ok := cache.Get(cacheKey)
if !ok {
return "", ok
}
return val.(string), ok
}

View file

@ -171,6 +171,7 @@ func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request)
prometheus.CounterOpts{
Name: metricsName,
Help: "Total number of http requests for HTML resources",
Namespace: "dendrite",
},
[]string{"code"},
),
@ -201,7 +202,28 @@ func MakeInternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
h.ServeHTTP(w, req)
}
return http.HandlerFunc(withSpan)
return promhttp.InstrumentHandlerCounter(
promauto.NewCounterVec(
prometheus.CounterOpts{
Name: metricsName + "_requests_total",
Help: "Total number of internal API calls",
Namespace: "dendrite",
},
[]string{"code"},
),
promhttp.InstrumentHandlerResponseSize(
promauto.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",
Name: metricsName + "_response_size_bytes",
Help: "A histogram of response sizes for requests.",
Buckets: []float64{200, 500, 900, 1500, 5000, 15000, 50000, 100000},
},
[]string{},
),
http.HandlerFunc(withSpan),
),
)
}
// MakeFedAPI makes an http.Handler that checks matrix federation authentication.

View file

@ -3,8 +3,6 @@ package pushgateway
import (
"context"
"encoding/json"
"github.com/matrix-org/gomatrixserverlib"
)
// A Client is how interactions with a Push Gateway is done.
@ -50,7 +48,7 @@ type Device struct {
AppID string `json:"app_id"` // Required
Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys.
PushKey string `json:"pushkey"` // Required
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
PushKeyTS int64 `json:"pushkey_ts,omitempty"`
Tweaks map[string]interface{} `json:"tweaks,omitempty"`
}

View file

@ -32,7 +32,7 @@ func AddPublicRoutes(
userAPI userapi.UserInternalAPI,
client *gomatrixserverlib.Client,
) {
mediaDB, err := storage.Open(&cfg.Database)
mediaDB, err := storage.NewMediaAPIDatasource(&cfg.Database)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to media db")
}

View file

@ -22,6 +22,7 @@ import (
"io"
"net/http"
"net/url"
"os"
"path"
"strings"
@ -311,6 +312,26 @@ func (r *uploadRequest) storeFileAndMetadata(
}
go func() {
file, err := os.Open(string(finalPath))
if err != nil {
r.Logger.WithError(err).Error("unable to open file")
return
}
defer file.Close() // nolint: errcheck
// http.DetectContentType only needs 512 bytes
buf := make([]byte, 512)
_, err = file.Read(buf)
if err != nil {
r.Logger.WithError(err).Error("unable to read file")
return
}
// Check if we need to generate thumbnails
fileType := http.DetectContentType(buf)
if !strings.HasPrefix(fileType, "image") {
r.Logger.WithField("contentType", fileType).Debugf("uploaded file is not an image or can not be thumbnailed, not generating thumbnails")
return
}
busy, err := thumbnailer.GenerateThumbnails(
context.Background(), finalPath, thumbnailSizes, r.MediaMetadata,
activeThumbnailGeneration, maxThumbnailGenerators, db, r.Logger,

View file

@ -51,7 +51,7 @@ func Test_uploadRequest_doUpload(t *testing.T) {
_ = os.Mkdir(testdataPath, os.ModePerm)
defer fileutils.RemoveDir(types.Path(testdataPath), nil)
db, err := storage.Open(&config.DatabaseOptions{
db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{
ConnectionString: "file::memory:?cache=shared",
MaxOpenConnections: 100,
MaxIdleConnections: 2,

View file

@ -22,9 +22,17 @@ import (
)
type Database interface {
MediaRepository
Thumbnails
}
type MediaRepository interface {
StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error
GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
}
type Thumbnails interface {
StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error
GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error)
GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error)

View file

@ -20,6 +20,8 @@ import (
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -69,24 +71,25 @@ type mediaStatements struct {
selectMediaByHashStmt *sql.Stmt
}
func (s *mediaStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(mediaSchema)
func NewPostgresMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) {
s := &mediaStatements{}
_, err := db.Exec(mediaSchema)
if err != nil {
return
return nil, err
}
return statementList{
return s, sqlutil.StatementList{
{&s.insertMediaStmt, insertMediaSQL},
{&s.selectMediaStmt, selectMediaSQL},
{&s.selectMediaByHashStmt, selectMediaByHashSQL},
}.prepare(db)
}.Prepare(db)
}
func (s *mediaStatements) insertMedia(
ctx context.Context, mediaMetadata *types.MediaMetadata,
func (s *mediaStatements) InsertMedia(
ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata,
) error {
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
_, err := s.insertMediaStmt.ExecContext(
mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext(
ctx,
mediaMetadata.MediaID,
mediaMetadata.Origin,
@ -100,14 +103,14 @@ func (s *mediaStatements) insertMedia(
return err
}
func (s *mediaStatements) selectMedia(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
func (s *mediaStatements) SelectMedia(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
MediaID: mediaID,
Origin: mediaOrigin,
}
err := s.selectMediaStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext(
ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,
@ -120,14 +123,14 @@ func (s *mediaStatements) selectMedia(
return &mediaMetadata, err
}
func (s *mediaStatements) selectMediaByHash(
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
func (s *mediaStatements) SelectMediaByHash(
ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
Base64Hash: mediaHash,
Origin: mediaOrigin,
}
err := s.selectMediaStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext(
ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,

View file

@ -0,0 +1,46 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/shared"
"github.com/matrix-org/dendrite/setup/config"
)
// NewDatabase opens a postgres database.
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
mediaRepo, err := NewPostgresMediaRepositoryTable(db)
if err != nil {
return nil, err
}
thumbnails, err := NewPostgresThumbnailsTable(db)
if err != nil {
return nil, err
}
return &shared.Database{
MediaRepository: mediaRepo,
Thumbnails: thumbnails,
DB: db,
Writer: sqlutil.NewExclusiveWriter(),
}, nil
}

View file

@ -1,38 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// FIXME: This should be made internal!
package postgres
import (
"database/sql"
)
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
type statementList []struct {
statement **sql.Stmt
sql string
}
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
func (s statementList) prepare(db *sql.DB) (err error) {
for _, statement := range s {
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
return
}
}
return
}

View file

@ -21,6 +21,8 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -63,7 +65,7 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE
// Note: this selects all thumbnails for a media_origin and media_id
const selectThumbnailsSQL = `
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC
`
type thumbnailStatements struct {
@ -72,24 +74,25 @@ type thumbnailStatements struct {
selectThumbnailsStmt *sql.Stmt
}
func (s *thumbnailStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(thumbnailSchema)
func NewPostgresThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) {
s := &thumbnailStatements{}
_, err := db.Exec(thumbnailSchema)
if err != nil {
return
return nil, err
}
return statementList{
return s, sqlutil.StatementList{
{&s.insertThumbnailStmt, insertThumbnailSQL},
{&s.selectThumbnailStmt, selectThumbnailSQL},
{&s.selectThumbnailsStmt, selectThumbnailsSQL},
}.prepare(db)
}.Prepare(db)
}
func (s *thumbnailStatements) insertThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
func (s *thumbnailStatements) InsertThumbnail(
ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata,
) error {
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
_, err := s.insertThumbnailStmt.ExecContext(
thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext(
ctx,
thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin,
@ -103,8 +106,9 @@ func (s *thumbnailStatements) insertThumbnail(
return err
}
func (s *thumbnailStatements) selectThumbnail(
func (s *thumbnailStatements) SelectThumbnail(
ctx context.Context,
txn *sql.Tx,
mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
width, height int,
@ -121,7 +125,7 @@ func (s *thumbnailStatements) selectThumbnail(
ResizeMethod: resizeMethod,
},
}
err := s.selectThumbnailStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext(
ctx,
thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin,
@ -136,10 +140,10 @@ func (s *thumbnailStatements) selectThumbnail(
return &thumbnailMetadata, err
}
func (s *thumbnailStatements) selectThumbnails(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
func (s *thumbnailStatements) SelectThumbnails(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error) {
rows, err := s.selectThumbnailsStmt.QueryContext(
rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext(
ctx, mediaID, mediaOrigin,
)
if err != nil {

View file

@ -1,5 +1,4 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,54 +12,38 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
package shared
import (
"context"
"database/sql"
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database is used to store metadata about a repository of media files.
type Database struct {
statements statements
db *sql.DB
}
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
var d Database
var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
if err = d.statements.prepare(d.db); err != nil {
return nil, err
}
return &d, nil
DB *sql.DB
Writer sqlutil.Writer
MediaRepository tables.MediaRepository
Thumbnails tables.Thumbnails
}
// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreMediaMetadata(
ctx context.Context, mediaMetadata *types.MediaMetadata,
) error {
return d.statements.media.insertMedia(ctx, mediaMetadata)
func (d Database) StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.MediaRepository.InsertMedia(ctx, txn, mediaMetadata)
})
}
// GetMediaMetadata returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadata(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin)
func (d Database) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
mediaMetadata, err := d.MediaRepository.SelectMedia(ctx, nil, mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
@ -70,10 +53,8 @@ func (d *Database) GetMediaMetadata(
// GetMediaMetadataByHash returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadataByHash(
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin)
func (d Database) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
mediaMetadata, err := d.MediaRepository.SelectMediaByHash(ctx, nil, mediaHash, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
@ -82,40 +63,36 @@ func (d *Database) GetMediaMetadataByHash(
// StoreThumbnail inserts the metadata about the thumbnail into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
) error {
return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata)
func (d Database) StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Thumbnails.InsertThumbnail(ctx, txn, thumbnailMetadata)
})
}
// GetThumbnail returns metadata about a specific thumbnail.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this thumbnail.
func (d *Database) GetThumbnail(
ctx context.Context,
mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
width, height int,
resizeMethod string,
) (*types.ThumbnailMetadata, error) {
thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail(
ctx, mediaID, mediaOrigin, width, height, resizeMethod,
)
if err != nil && err == sql.ErrNoRows {
func (d Database) GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) {
metadata, err := d.Thumbnails.SelectThumbnail(ctx, nil, mediaID, mediaOrigin, width, height, resizeMethod)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return thumbnailMetadata, err
return nil, err
}
return metadata, err
}
// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there are no thumbnails associated with this media.
func (d *Database) GetThumbnails(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error) {
thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
func (d Database) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) {
metadatas, err := d.Thumbnails.SelectThumbnails(ctx, nil, mediaID, mediaOrigin)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return thumbnails, err
return nil, err
}
return metadatas, err
}

View file

@ -21,6 +21,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -66,35 +67,32 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, media_id, user_i
type mediaStatements struct {
db *sql.DB
writer sqlutil.Writer
insertMediaStmt *sql.Stmt
selectMediaStmt *sql.Stmt
selectMediaByHashStmt *sql.Stmt
}
func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
s.writer = writer
_, err = db.Exec(mediaSchema)
func NewSQLiteMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) {
s := &mediaStatements{
db: db,
}
_, err := db.Exec(mediaSchema)
if err != nil {
return
return nil, err
}
return statementList{
return s, sqlutil.StatementList{
{&s.insertMediaStmt, insertMediaSQL},
{&s.selectMediaStmt, selectMediaSQL},
{&s.selectMediaByHashStmt, selectMediaByHashSQL},
}.prepare(db)
}.Prepare(db)
}
func (s *mediaStatements) insertMedia(
ctx context.Context, mediaMetadata *types.MediaMetadata,
func (s *mediaStatements) InsertMedia(
ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata,
) error {
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertMediaStmt)
_, err := stmt.ExecContext(
mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext(
ctx,
mediaMetadata.MediaID,
mediaMetadata.Origin,
@ -106,17 +104,16 @@ func (s *mediaStatements) insertMedia(
mediaMetadata.UserID,
)
return err
})
}
func (s *mediaStatements) selectMedia(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
func (s *mediaStatements) SelectMedia(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
MediaID: mediaID,
Origin: mediaOrigin,
}
err := s.selectMediaStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext(
ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,
@ -129,14 +126,14 @@ func (s *mediaStatements) selectMedia(
return &mediaMetadata, err
}
func (s *mediaStatements) selectMediaByHash(
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
func (s *mediaStatements) SelectMediaByHash(
ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
Base64Hash: mediaHash,
Origin: mediaOrigin,
}
err := s.selectMediaStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext(
ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,

View file

@ -16,23 +16,30 @@
package sqlite3
import (
"database/sql"
// Import the postgres database driver.
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/shared"
"github.com/matrix-org/dendrite/setup/config"
)
type statements struct {
media mediaStatements
thumbnail thumbnailStatements
}
func (s *statements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
if err = s.media.prepare(db, writer); err != nil {
return
// NewDatabase opens a SQLIte database.
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
if err = s.thumbnail.prepare(db, writer); err != nil {
return
mediaRepo, err := NewSQLiteMediaRepositoryTable(db)
if err != nil {
return nil, err
}
return
thumbnails, err := NewSQLiteThumbnailsTable(db)
if err != nil {
return nil, err
}
return &shared.Database{
MediaRepository: mediaRepo,
Thumbnails: thumbnails,
DB: db,
Writer: sqlutil.NewExclusiveWriter(),
}, nil
}

View file

@ -1,38 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// FIXME: This should be made internal!
package sqlite3
import (
"database/sql"
)
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
type statementList []struct {
statement **sql.Stmt
sql string
}
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
func (s statementList) prepare(db *sql.DB) (err error) {
for _, statement := range s {
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
return
}
}
return
}

View file

@ -1,123 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
// Import the postgres database driver.
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database is used to store metadata about a repository of media files.
type Database struct {
statements statements
db *sql.DB
writer sqlutil.Writer
}
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
d := Database{
writer: sqlutil.NewExclusiveWriter(),
}
var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
if err = d.statements.prepare(d.db, d.writer); err != nil {
return nil, err
}
return &d, nil
}
// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreMediaMetadata(
ctx context.Context, mediaMetadata *types.MediaMetadata,
) error {
return d.statements.media.insertMedia(ctx, mediaMetadata)
}
// GetMediaMetadata returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadata(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return mediaMetadata, err
}
// GetMediaMetadataByHash returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadataByHash(
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return mediaMetadata, err
}
// StoreThumbnail inserts the metadata about the thumbnail into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
) error {
return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata)
}
// GetThumbnail returns metadata about a specific thumbnail.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this thumbnail.
func (d *Database) GetThumbnail(
ctx context.Context,
mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
width, height int,
resizeMethod string,
) (*types.ThumbnailMetadata, error) {
thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail(
ctx, mediaID, mediaOrigin, width, height, resizeMethod,
)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return thumbnailMetadata, err
}
// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there are no thumbnails associated with this media.
func (d *Database) GetThumbnails(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error) {
thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return thumbnails, err
}

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -54,39 +55,32 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE
// Note: this selects all thumbnails for a media_origin and media_id
const selectThumbnailsSQL = `
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC
`
type thumbnailStatements struct {
db *sql.DB
writer sqlutil.Writer
insertThumbnailStmt *sql.Stmt
selectThumbnailStmt *sql.Stmt
selectThumbnailsStmt *sql.Stmt
}
func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
_, err = db.Exec(thumbnailSchema)
func NewSQLiteThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) {
s := &thumbnailStatements{}
_, err := db.Exec(thumbnailSchema)
if err != nil {
return
return nil, err
}
s.db = db
s.writer = writer
return statementList{
return s, sqlutil.StatementList{
{&s.insertThumbnailStmt, insertThumbnailSQL},
{&s.selectThumbnailStmt, selectThumbnailSQL},
{&s.selectThumbnailsStmt, selectThumbnailsSQL},
}.prepare(db)
}.Prepare(db)
}
func (s *thumbnailStatements) insertThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
) error {
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt)
_, err := stmt.ExecContext(
func (s *thumbnailStatements) InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error {
thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext(
ctx,
thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin,
@ -98,11 +92,11 @@ func (s *thumbnailStatements) insertThumbnail(
thumbnailMetadata.ThumbnailSize.ResizeMethod,
)
return err
})
}
func (s *thumbnailStatements) selectThumbnail(
func (s *thumbnailStatements) SelectThumbnail(
ctx context.Context,
txn *sql.Tx,
mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
width, height int,
@ -119,7 +113,7 @@ func (s *thumbnailStatements) selectThumbnail(
ResizeMethod: resizeMethod,
},
}
err := s.selectThumbnailStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext(
ctx,
thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin,
@ -134,10 +128,11 @@ func (s *thumbnailStatements) selectThumbnail(
return &thumbnailMetadata, err
}
func (s *thumbnailStatements) selectThumbnails(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
func (s *thumbnailStatements) SelectThumbnails(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error) {
rows, err := s.selectThumbnailsStmt.QueryContext(
rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext(
ctx, mediaID, mediaOrigin,
)
if err != nil {

View file

@ -25,13 +25,13 @@ import (
"github.com/matrix-org/dendrite/setup/config"
)
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (Database, error) {
// NewMediaAPIDatasource opens a database connection.
func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.Open(dbProperties)
return sqlite3.NewDatabase(dbProperties)
case dbProperties.ConnectionString.IsPostgres():
return postgres.Open(dbProperties)
return postgres.NewDatabase(dbProperties)
default:
return nil, fmt.Errorf("unexpected database type")
}

View file

@ -0,0 +1,135 @@
package storage_test
import (
"context"
"reflect"
"testing"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("NewSyncServerDatasource returned %s", err)
}
return db, close
}
func TestMediaRepository(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
ctx := context.Background()
t.Run("can insert media & query media", func(t *testing.T) {
metadata := &types.MediaMetadata{
MediaID: "testing",
Origin: "localhost",
ContentType: "image/png",
FileSizeBytes: 10,
UploadName: "upload test",
Base64Hash: "dGVzdGluZw==",
UserID: "@alice:localhost",
}
if err := db.StoreMediaMetadata(ctx, metadata); err != nil {
t.Fatalf("unable to store media metadata: %v", err)
}
// query by media id
gotMetadata, err := db.GetMediaMetadata(ctx, metadata.MediaID, metadata.Origin)
if err != nil {
t.Fatalf("unable to query media metadata: %v", err)
}
if !reflect.DeepEqual(metadata, gotMetadata) {
t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata)
}
// query by media hash
gotMetadata, err = db.GetMediaMetadataByHash(ctx, metadata.Base64Hash, metadata.Origin)
if err != nil {
t.Fatalf("unable to query media metadata by hash: %v", err)
}
if !reflect.DeepEqual(metadata, gotMetadata) {
t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata)
}
})
})
}
func TestThumbnailsStorage(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
ctx := context.Background()
t.Run("can insert thumbnails & query media", func(t *testing.T) {
thumbnails := []*types.ThumbnailMetadata{
{
MediaMetadata: &types.MediaMetadata{
MediaID: "testing",
Origin: "localhost",
ContentType: "image/png",
FileSizeBytes: 6,
},
ThumbnailSize: types.ThumbnailSize{
Width: 5,
Height: 5,
ResizeMethod: types.Crop,
},
},
{
MediaMetadata: &types.MediaMetadata{
MediaID: "testing",
Origin: "localhost",
ContentType: "image/png",
FileSizeBytes: 7,
},
ThumbnailSize: types.ThumbnailSize{
Width: 1,
Height: 1,
ResizeMethod: types.Scale,
},
},
}
for i := range thumbnails {
if err := db.StoreThumbnail(ctx, thumbnails[i]); err != nil {
t.Fatalf("unable to store thumbnail metadata: %v", err)
}
}
// query by single thumbnail
gotMetadata, err := db.GetThumbnail(ctx,
thumbnails[0].MediaMetadata.MediaID,
thumbnails[0].MediaMetadata.Origin,
thumbnails[0].ThumbnailSize.Width, thumbnails[0].ThumbnailSize.Height,
thumbnails[0].ThumbnailSize.ResizeMethod,
)
if err != nil {
t.Fatalf("unable to query thumbnail metadata: %v", err)
}
if !reflect.DeepEqual(thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata) {
t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata)
}
if !reflect.DeepEqual(thumbnails[0].ThumbnailSize, gotMetadata.ThumbnailSize) {
t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata)
}
// query by all thumbnails
gotMediadatas, err := db.GetThumbnails(ctx, thumbnails[0].MediaMetadata.MediaID, thumbnails[0].MediaMetadata.Origin)
if err != nil {
t.Fatalf("unable to query media metadata by hash: %v", err)
}
if len(gotMediadatas) != len(thumbnails) {
t.Fatalf("expected %d stored thumbnail metadata, got %d", len(thumbnails), len(gotMediadatas))
}
for i := range gotMediadatas {
if !reflect.DeepEqual(thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata) {
t.Fatalf("expected metadata %+v, got %v", thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata)
}
if !reflect.DeepEqual(thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize) {
t.Fatalf("expected metadata %+v, got %v", thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize)
}
}
})
})
}

View file

@ -22,10 +22,10 @@ import (
)
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (Database, error) {
func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.Open(dbProperties)
return sqlite3.NewDatabase(dbProperties)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:

View file

@ -0,0 +1,46 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
type Thumbnails interface {
InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error
SelectThumbnail(
ctx context.Context, txn *sql.Tx,
mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
width, height int,
resizeMethod string,
) (*types.ThumbnailMetadata, error)
SelectThumbnails(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error)
}
type MediaRepository interface {
InsertMedia(ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata) error
SelectMedia(ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
SelectMediaByHash(
ctx context.Context, txn *sql.Tx,
mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error)
}

View file

@ -45,16 +45,13 @@ type RequestMethod string
// MatrixUserID is a Matrix user ID string in the form @user:domain e.g. @alice:matrix.org
type MatrixUserID string
// UnixMs is the milliseconds since the Unix epoch
type UnixMs int64
// MediaMetadata is metadata associated with a media file
type MediaMetadata struct {
MediaID MediaID
Origin gomatrixserverlib.ServerName
ContentType ContentType
FileSizeBytes FileSizeBytes
CreationTimestamp UnixMs
CreationTimestamp gomatrixserverlib.Timestamp
UploadName Filename
Base64Hash Base64Hash
UserID MatrixUserID

View file

@ -167,6 +167,7 @@ func (r *Inputer) startWorkerForRoom(roomID string) {
// will look to see if we have a worker for that room which has its
// own consumer. If we don't, we'll start one.
func (r *Inputer) Start() error {
prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration)
_, err := r.JetStream.Subscribe(
"", // This is blank because we specified it in BindStream.
func(m *nats.Msg) {
@ -421,10 +422,6 @@ func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) er
return nil
}
func init() {
prometheus.MustRegister(roomserverInputBackpressure)
}
var roomserverInputBackpressure = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: "dendrite",

View file

@ -37,10 +37,6 @@ import (
"github.com/sirupsen/logrus"
)
func init() {
prometheus.MustRegister(processRoomEventDuration)
}
// TODO: Does this value make sense?
const MaximumMissingProcessingTime = time.Minute * 2

View file

@ -2,7 +2,6 @@ package input_test
import (
"context"
"fmt"
"os"
"testing"
"time"
@ -12,30 +11,22 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/input"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
)
func psqlConnectionString() config.DataSource {
user := os.Getenv("POSTGRES_USER")
if user == "" {
user = "dendrite"
}
dbName := os.Getenv("POSTGRES_DB")
if dbName == "" {
dbName = "dendrite"
}
connStr := fmt.Sprintf(
"user=%s dbname=%s sslmode=disable", user, dbName,
)
password := os.Getenv("POSTGRES_PASSWORD")
if password != "" {
connStr += fmt.Sprintf(" password=%s", password)
}
host := os.Getenv("POSTGRES_HOST")
if host != "" {
connStr += fmt.Sprintf(" host=%s", host)
}
return config.DataSource(connStr)
var js nats.JetStreamContext
var jc *nats.Conn
func TestMain(m *testing.M) {
var pc *process.ProcessContext
pc, js, jc = jetstream.PrepareForTests()
code := m.Run()
pc.ShutdownDendrite()
pc.WaitForComponentsToFinish()
os.Exit(code)
}
func TestSingleTransactionOnInput(t *testing.T) {
@ -63,7 +54,7 @@ func TestSingleTransactionOnInput(t *testing.T) {
}
db, err := storage.Open(
&config.DatabaseOptions{
ConnectionString: psqlConnectionString(),
ConnectionString: "",
MaxOpenConnections: 1,
MaxIdleConnections: 1,
},
@ -75,6 +66,8 @@ func TestSingleTransactionOnInput(t *testing.T) {
}
inputter := &input.Inputer{
DB: db,
JetStream: js,
NATSClient: jc,
}
res := &api.InputRoomEventsResponse{}
inputter.InputRoomEvents(

View file

@ -13,12 +13,22 @@ import (
"github.com/sirupsen/logrus"
natsserver "github.com/nats-io/nats-server/v2/server"
"github.com/nats-io/nats.go"
natsclient "github.com/nats-io/nats.go"
)
var natsServer *natsserver.Server
var natsServerMutex sync.Mutex
func PrepareForTests() (*process.ProcessContext, nats.JetStreamContext, *nats.Conn) {
cfg := &config.Dendrite{}
cfg.Defaults(true)
cfg.Global.JetStream.InMemory = true
pc := process.NewProcessContext()
js, jc := Prepare(pc, &cfg.Global.JetStream)
return pc, js, jc
}
func Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient.JetStreamContext, *natsclient.Conn) {
// check if we need an in-process NATS Server
if len(cfg.Addresses) != 0 {

View file

@ -36,7 +36,7 @@ import (
type Notifier struct {
lock *sync.RWMutex
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
roomIDToJoinedUsers map[string]userIDSet
roomIDToJoinedUsers map[string]*userIDSet
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
roomIDToPeekingDevices map[string]peekingDeviceSet
// The latest sync position
@ -54,7 +54,7 @@ type Notifier struct {
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
func NewNotifier() *Notifier {
return &Notifier{
roomIDToJoinedUsers: make(map[string]userIDSet),
roomIDToJoinedUsers: make(map[string]*userIDSet),
roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
lock: &sync.RWMutex{},
@ -262,7 +262,7 @@ func (n *Notifier) SharedUsers(userID string) []string {
func (n *Notifier) _sharedUsers(userID string) []string {
n._sharedUserMap[userID] = struct{}{}
for roomID, users := range n.roomIDToJoinedUsers {
if _, ok := users[userID]; !ok {
if ok := users.isIn(userID); !ok {
continue
}
for _, userID := range n._joinedUsers(roomID) {
@ -282,8 +282,11 @@ func (n *Notifier) IsSharedUser(userA, userB string) bool {
defer n.lock.RUnlock()
var okA, okB bool
for _, users := range n.roomIDToJoinedUsers {
_, okA = users[userA]
_, okB = users[userB]
okA = users.isIn(userA)
if !okA {
continue
}
okB = users.isIn(userB)
if okA && okB {
return true
}
@ -345,11 +348,12 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
// This is just the bulk form of addJoinedUser
for roomID, userIDs := range roomIDToUserIDs {
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
n.roomIDToJoinedUsers[roomID] = make(userIDSet, len(userIDs))
n.roomIDToJoinedUsers[roomID] = newUserIDSet(len(userIDs))
}
for _, userID := range userIDs {
n.roomIDToJoinedUsers[roomID].add(userID)
}
n.roomIDToJoinedUsers[roomID].precompute()
}
}
@ -440,16 +444,18 @@ func (n *Notifier) _fetchUserStreams(userID string) []*UserDeviceStream {
func (n *Notifier) _addJoinedUser(roomID, userID string) {
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
n.roomIDToJoinedUsers[roomID] = make(userIDSet)
n.roomIDToJoinedUsers[roomID] = newUserIDSet(8)
}
n.roomIDToJoinedUsers[roomID].add(userID)
n.roomIDToJoinedUsers[roomID].precompute()
}
func (n *Notifier) _removeJoinedUser(roomID, userID string) {
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
n.roomIDToJoinedUsers[roomID] = make(userIDSet)
n.roomIDToJoinedUsers[roomID] = newUserIDSet(8)
}
n.roomIDToJoinedUsers[roomID].remove(userID)
n.roomIDToJoinedUsers[roomID].precompute()
}
func (n *Notifier) JoinedUsers(roomID string) (userIDs []string) {
@ -521,19 +527,52 @@ func (n *Notifier) _removeEmptyUserStreams() {
}
// A string set, mainly existing for improving clarity of structs in this file.
type userIDSet map[string]struct{}
func (s userIDSet) add(str string) {
s[str] = struct{}{}
type userIDSet struct {
sync.Mutex
set map[string]struct{}
precomputed []string
}
func (s userIDSet) remove(str string) {
delete(s, str)
func newUserIDSet(cap int) *userIDSet {
return &userIDSet{
set: make(map[string]struct{}, cap),
precomputed: nil,
}
}
func (s userIDSet) values() (vals []string) {
vals = make([]string, 0, len(s))
for str := range s {
func (s *userIDSet) add(str string) {
s.Lock()
defer s.Unlock()
s.set[str] = struct{}{}
s.precomputed = s.precomputed[:0] // invalidate cache
}
func (s *userIDSet) remove(str string) {
s.Lock()
defer s.Unlock()
delete(s.set, str)
s.precomputed = s.precomputed[:0] // invalidate cache
}
func (s *userIDSet) precompute() {
s.Lock()
defer s.Unlock()
s.precomputed = s.values()
}
func (s *userIDSet) isIn(str string) bool {
s.Lock()
defer s.Unlock()
_, ok := s.set[str]
return ok
}
func (s *userIDSet) values() (vals []string) {
if len(s.precomputed) > 0 {
return s.precomputed // only return if not invalidated
}
vals = make([]string, 0, len(s.set))
for str := range s.set {
vals = append(vals, str)
}
return

View file

@ -165,9 +165,9 @@ func TestCorrectStreamWakeup(t *testing.T) {
go func() {
select {
case <-streamone.signalChannel:
case <-streamone.ch():
awoken <- "one"
case <-streamtwo.signalChannel:
case <-streamtwo.ch():
awoken <- "two"
}
}()

View file

@ -118,6 +118,12 @@ func (s *UserDeviceStream) TimeOfLastNonEmpty() time.Time {
return s.timeOfLastChannel
}
func (s *UserDeviceStream) ch() <-chan struct{} {
s.lock.Lock()
defer s.lock.Unlock()
return s.signalChannel
}
// GetSyncPosition returns last sync position which the UserStream was
// notified about
func (s *UserDeviceStreamListener) GetSyncPosition() types.StreamingToken {

View file

@ -60,7 +60,9 @@ func Context(
Headers: nil,
}
}
filter.Rooms = append(filter.Rooms, roomID)
if filter.Rooms != nil {
*filter.Rooms = append(*filter.Rooms, roomID)
}
ctx := req.Context()
membershipRes := roomserver.QueryMembershipForUserResponse{}

View file

@ -262,12 +262,8 @@ func (r *messagesReq) retrieveEvents() (
clientEvents []gomatrixserverlib.ClientEvent, start,
end types.TopologyToken, err error,
) {
eventFilter := r.filter
// Retrieve the events from the local database.
streamEvents, err := r.db.GetEventsInTopologicalRange(
r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
)
streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
if err != nil {
err = fmt.Errorf("GetEventsInRange: %w", err)
return

View file

@ -104,8 +104,8 @@ type Database interface {
// DeletePeek deletes all peeks for a given room by a given user
// Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event.
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.

View file

@ -47,14 +47,10 @@ const selectBackwardExtremitiesForRoomSQL = "" +
const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
const deleteBackwardExtremitiesForRoomSQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
type backwardExtremitiesStatements struct {
insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
}
func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
@ -72,9 +68,6 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
return nil, err
}
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
return nil, err
}
return s, nil
}
@ -113,10 +106,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return
}
func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -233,9 +233,10 @@ func (s *currentRoomStateStatements) SelectCurrentState(
excludeEventIDs []string,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err := stmt.QueryContext(ctx, roomID,
pq.StringArray(stateFilter.Senders),
pq.StringArray(stateFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
stateFilter.ContainsURL,

View file

@ -16,21 +16,45 @@ package postgres
import (
"strings"
"github.com/matrix-org/gomatrixserverlib"
)
// filterConvertWildcardToSQL converts wildcards as defined in
// https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter
// to SQL wildcards that can be used with LIKE()
func filterConvertTypeWildcardToSQL(values []string) []string {
func filterConvertTypeWildcardToSQL(values *[]string) []string {
if values == nil {
// Return nil instead of []string{} so IS NULL can work correctly when
// the return value is passed into SQL queries
return nil
}
ret := make([]string, len(values))
for i := range values {
ret[i] = strings.Replace(values[i], "*", "%", -1)
v := *values
ret := make([]string, len(v))
for i := range v {
ret[i] = strings.Replace(v[i], "*", "%", -1)
}
return ret
}
// TODO: Replace when Dendrite uses Go 1.18
func getSendersRoomEventFilter(filter *gomatrixserverlib.RoomEventFilter) (senders []string, notSenders []string) {
if filter.Senders != nil {
senders = *filter.Senders
}
if filter.NotSenders != nil {
notSenders = *filter.NotSenders
}
return senders, notSenders
}
func getSendersStateFilterFilter(filter *gomatrixserverlib.StateFilter) (senders []string, notSenders []string) {
if filter.Senders != nil {
senders = *filter.Senders
}
if filter.NotSenders != nil {
notSenders = *filter.NotSenders
}
return senders, notSenders
}

View file

@ -56,12 +56,6 @@ const upsertMembershipSQL = "" +
" ON CONFLICT ON CONSTRAINT syncapi_memberships_unique" +
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
const selectMembershipSQL = "" +
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
" WHERE room_id = $1 AND user_id = $2 AND membership = ANY($3)" +
" ORDER BY stream_pos DESC" +
" LIMIT 1"
const selectMembershipCountSQL = "" +
"SELECT COUNT(*) FROM (" +
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
@ -69,7 +63,6 @@ const selectMembershipCountSQL = "" +
type membershipsStatements struct {
upsertMembershipStmt *sql.Stmt
selectMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
}
@ -82,9 +75,6 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil {
return nil, err
}
if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil {
return nil, err
}
if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil {
return nil, err
}
@ -111,14 +101,6 @@ func (s *membershipsStatements) UpsertMembership(
return err
}
func (s *membershipsStatements) SelectMembership(
ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt)
err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos)
return
}
func (s *membershipsStatements) SelectMembershipCount(
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
) (count int, err error) {

View file

@ -81,6 +81,15 @@ const insertEventSQL = "" +
const selectEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)"
const selectEventsWithFilterSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" +
" AND ( $2::text[] IS NULL OR sender = ANY($2) )" +
" AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" +
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" +
" AND ( $6::bool IS NULL OR contains_url = $6 )" +
" LIMIT $7"
const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
@ -153,6 +162,7 @@ const selectContextAfterEventSQL = "" +
type outputRoomEventsStatements struct {
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
selectEventsWitFilterStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt
selectRecentEventsForSyncStmt *sql.Stmt
@ -174,6 +184,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL},
{&s.selectEventsStmt, selectEventsSQL},
{&s.selectEventsWitFilterStmt, selectEventsWithFilterSQL},
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
{&s.selectRecentEventsStmt, selectRecentEventsSQL},
{&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL},
@ -204,11 +215,11 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
stateFilter *gomatrixserverlib.StateFilter, roomIDs []string,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err := stmt.QueryContext(
ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
pq.StringArray(stateFilter.Senders),
pq.StringArray(stateFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
stateFilter.ContainsURL,
@ -310,7 +321,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
// Parse content as JSON and search for an "url" key
containsURL := false
var content map[string]interface{}
if json.Unmarshal(event.Content(), &content) != nil {
if json.Unmarshal(event.Content(), &content) == nil {
// Set containsURL to true if url is present
_, containsURL = content["url"]
}
@ -353,10 +364,11 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
} else {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
}
senders, notSenders := getSendersRoomEventFilter(eventFilter)
rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders),
pq.StringArray(eventFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit+1,
@ -398,11 +410,12 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
) ([]types.StreamEvent, error) {
senders, notSenders := getSendersRoomEventFilter(eventFilter)
stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt)
rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders),
pq.StringArray(eventFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit,
@ -427,15 +440,52 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
// selectEvents returns the events for the given event IDs. If an event is
// missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) SelectEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string,
ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
) ([]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
var (
stmt *sql.Stmt
rows *sql.Rows
err error
)
if filter == nil {
stmt = sqlutil.TxStmt(txn, s.selectEventsStmt)
rows, err = stmt.QueryContext(ctx, pq.StringArray(eventIDs))
} else {
senders, notSenders := getSendersRoomEventFilter(filter)
stmt = sqlutil.TxStmt(txn, s.selectEventsWitFilterStmt)
rows, err = stmt.QueryContext(ctx,
pq.StringArray(eventIDs),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
filter.ContainsURL,
filter.Limit,
)
}
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
return rowsToStreamEvents(rows)
streamEvents, err := rowsToStreamEvents(rows)
if err != nil {
return nil, err
}
if preserveOrder {
eventMap := make(map[string]types.StreamEvent)
for _, ev := range streamEvents {
eventMap[ev.EventID()] = ev
}
var returnEvents []types.StreamEvent
for _, eventID := range eventIDs {
ev, ok := eventMap[eventID]
if ok {
returnEvents = append(returnEvents, ev)
}
}
return returnEvents, nil
}
return streamEvents, nil
}
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
@ -462,10 +512,11 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn
func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
) (evts []*gomatrixserverlib.HeaderedEvent, err error) {
senders, notSenders := getSendersRoomEventFilter(filter)
rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext(
ctx, roomID, id, filter.Limit,
pq.StringArray(filter.Senders),
pq.StringArray(filter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
)
@ -494,10 +545,11 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
func (s *outputRoomEventsStatements) SelectContextAfterEvent(
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) {
senders, notSenders := getSendersRoomEventFilter(filter)
rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext(
ctx, roomID, id, filter.Limit,
pq.StringArray(filter.Senders),
pq.StringArray(filter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
)

View file

@ -73,9 +73,6 @@ const selectMaxPositionInTopologySQL = "" +
"SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" +
") ORDER BY stream_position DESC LIMIT 1"
const deleteTopologyForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
const selectStreamToTopologicalPositionAscSQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;"
@ -88,7 +85,6 @@ type outputRoomEventsTopologyStatements struct {
selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt
deleteTopologyForRoomStmt *sql.Stmt
selectStreamToTopologicalPositionAscStmt *sql.Stmt
selectStreamToTopologicalPositionDescStmt *sql.Stmt
}
@ -114,9 +110,6 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) {
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
return nil, err
}
if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil {
return nil, err
}
if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil {
return nil, err
}
@ -148,9 +141,9 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
// is requested or not.
var stmt *sql.Stmt
if chronologicalOrder {
stmt = s.selectEventIDsInRangeASCStmt
stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt)
} else {
stmt = s.selectEventIDsInRangeDESCStmt
stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt)
}
// Query the event IDs.
@ -203,10 +196,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return
}
func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre
// Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events.
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs)
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false)
if err != nil {
return nil, err
}
@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
// Check if we have all of the event's previous events. If an event is
// missing, add it to the room's backward extremities.
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs())
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false)
if err != nil {
return err
}
@ -429,7 +429,8 @@ func (d *Database) updateRoomState(
func (d *Database) GetEventsInTopologicalRange(
ctx context.Context,
from, to *types.TopologyToken,
roomID string, limit int,
roomID string,
filter *gomatrixserverlib.RoomEventFilter,
backwardOrdering bool,
) (events []types.StreamEvent, err error) {
var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
@ -450,14 +451,14 @@ func (d *Database) GetEventsInTopologicalRange(
// Select the event IDs from the defined range.
var eIDs []string
eIDs, err = d.Topology.SelectEventIDsInRange(
ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, limit, !backwardOrdering,
ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
)
if err != nil {
return
}
// Retrieve the events' contents using their IDs.
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs)
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, filter, true)
return
}
@ -619,7 +620,7 @@ func (d *Database) fetchMissingStateEvents(
) ([]types.StreamEvent, error) {
// Fetch from the events table first so we pick up the stream ID for the
// event.
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs)
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false)
if err != nil {
return nil, err
}

View file

@ -41,23 +41,23 @@ const insertAccountDataSQL = "" +
" ON CONFLICT (user_id, room_id, type) DO UPDATE" +
" SET id = $5"
// further parameters are added by prepareWithFilters
const selectAccountDataInRangeSQL = "" +
"SELECT room_id, type FROM syncapi_account_data_type" +
" WHERE user_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC"
" WHERE user_id = $1 AND id > $2 AND id <= $3"
const selectMaxAccountDataIDSQL = "" +
"SELECT MAX(id) FROM syncapi_account_data_type"
type accountDataStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
insertAccountDataStmt *sql.Stmt
selectMaxAccountDataIDStmt *sql.Stmt
selectAccountDataInRangeStmt *sql.Stmt
}
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
func NewSqliteAccountDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{
db: db,
streamIDStatements: streamID,
@ -94,18 +94,24 @@ func (s *accountDataStatements) SelectAccountDataInRange(
ctx context.Context,
userID string,
r types.Range,
accountDataFilterPart *gomatrixserverlib.EventFilter,
filter *gomatrixserverlib.EventFilter,
) (data map[string][]string, err error) {
data = make(map[string][]string)
stmt, params, err := prepareWithFilters(
s.db, nil, selectAccountDataInRangeSQL,
[]interface{}{
userID, r.Low(), r.High(),
},
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
[]string{}, nil, filter.Limit, FilterOrderAsc)
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High())
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
var entries int
for rows.Next() {
var dataType string
var roomID string
@ -114,31 +120,11 @@ func (s *accountDataStatements) SelectAccountDataInRange(
return
}
// check if we should add this by looking at the filter.
// It would be nice if we could do this in SQL-land, but the mix of variadic
// and positional parameters makes the query annoyingly hard to do, it's easier
// and clearer to do it in Go-land. If there are no filters for [not]types then
// this gets skipped.
for _, includeType := range accountDataFilterPart.Types {
if includeType != dataType { // TODO: wildcard support
continue
}
}
for _, excludeType := range accountDataFilterPart.NotTypes {
if excludeType == dataType { // TODO: wildcard support
continue
}
}
if len(data[roomID]) > 0 {
data[roomID] = append(data[roomID], dataType)
} else {
data[roomID] = []string{dataType}
}
entries++
if entries >= accountDataFilterPart.Limit {
break
}
}
return data, nil

View file

@ -47,15 +47,11 @@ const selectBackwardExtremitiesForRoomSQL = "" +
const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
const deleteBackwardExtremitiesForRoomSQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
type backwardExtremitiesStatements struct {
db *sql.DB
insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
}
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
@ -75,9 +71,6 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
return nil, err
}
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
return nil, err
}
return s, nil
}
@ -116,10 +109,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return err
}
func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -90,7 +90,7 @@ const selectEventsWithEventIDsSQL = "" +
type currentRoomStateStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
deleteRoomStateForRoomStmt *sql.Stmt
@ -100,7 +100,7 @@ type currentRoomStateStatements struct {
selectStateEventStmt *sql.Stmt
}
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{
db: db,
streamIDStatements: streamID,
@ -220,7 +220,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
},
stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes,
excludeEventIDs, stateFilter.Limit, FilterOrderNone,
excludeEventIDs, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderNone,
)
if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)

View file

@ -25,33 +25,52 @@ const (
// parts.
func prepareWithFilters(
db *sql.DB, txn *sql.Tx, query string, params []interface{},
senders, notsenders, types, nottypes []string, excludeEventIDs []string,
limit int, order FilterOrder,
senders, notsenders, types, nottypes *[]string, excludeEventIDs []string,
containsURL *bool, limit int, order FilterOrder,
) (*sql.Stmt, []interface{}, error) {
offset := len(params)
if count := len(senders); count > 0 {
if senders != nil {
if count := len(*senders); count > 0 {
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range senders {
for _, v := range *senders {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND sender = ""`
}
if count := len(notsenders); count > 0 {
}
if notsenders != nil {
if count := len(*notsenders); count > 0 {
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range notsenders {
for _, v := range *notsenders {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND sender NOT = ""`
}
if count := len(types); count > 0 {
}
if types != nil {
if count := len(*types); count > 0 {
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range types {
for _, v := range *types {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND type = ""`
}
if count := len(nottypes); count > 0 {
}
if nottypes != nil {
if count := len(*nottypes); count > 0 {
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range nottypes {
for _, v := range *nottypes {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND type NOT = ""`
}
}
if containsURL != nil {
query += fmt.Sprintf(" AND contains_url = %v", *containsURL)
}
if count := len(excludeEventIDs); count > 0 {
query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset)

View file

@ -59,14 +59,14 @@ const selectMaxInviteIDSQL = "" +
type inviteEventsStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt
deleteInviteEventStmt *sql.Stmt
selectMaxInviteIDStmt *sql.Stmt
}
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) {
s := &inviteEventsStatements{
db: db,
streamIDStatements: streamID,

View file

@ -18,7 +18,6 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
@ -57,12 +56,6 @@ const upsertMembershipSQL = "" +
" ON CONFLICT (room_id, user_id, membership)" +
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
const selectMembershipSQL = "" +
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
" WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" +
" ORDER BY stream_pos DESC" +
" LIMIT 1"
const selectMembershipCountSQL = "" +
"SELECT COUNT(*) FROM (" +
" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +
@ -111,22 +104,6 @@ func (s *membershipsStatements) UpsertMembership(
return err
}
func (s *membershipsStatements) SelectMembership(
ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
params := []interface{}{roomID, userID}
for _, membership := range memberships {
params = append(params, membership)
}
orig := strings.Replace(selectMembershipSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
stmt, err := s.db.Prepare(orig)
if err != nil {
return "", 0, 0, err
}
err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos)
return
}
func (s *membershipsStatements) SelectMembershipCount(
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
) (count int, err error) {

View file

@ -58,7 +58,7 @@ const insertEventSQL = "" +
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
const selectEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)"
const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
@ -111,9 +111,8 @@ const selectContextAfterEventSQL = "" +
type outputRoomEventsStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt
@ -122,7 +121,7 @@ type outputRoomEventsStatements struct {
selectContextAfterEventStmt *sql.Stmt
}
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) {
s := &outputRoomEventsStatements{
db: db,
streamIDStatements: streamID,
@ -133,7 +132,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
}
return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL},
{&s.selectEventsStmt, selectEventsSQL},
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
{&s.updateEventJSONStmt, updateEventJSONSQL},
{&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL},
@ -170,7 +168,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
s.db, txn, stmtSQL, inputParams,
stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes,
nil, stateFilter.Limit, FilterOrderAsc,
nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc,
)
if err != nil {
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -279,7 +277,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
// Parse content as JSON and search for an "url" key
containsURL := false
var content map[string]interface{}
if json.Unmarshal(event.Content(), &content) != nil {
if json.Unmarshal(event.Content(), &content) == nil {
// Set containsURL to true if url is present
_, containsURL = content["url"]
}
@ -347,7 +345,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
},
eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes,
nil, eventFilter.Limit+1, FilterOrderDesc,
nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc,
)
if err != nil {
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -395,7 +393,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
},
eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes,
nil, eventFilter.Limit, FilterOrderAsc,
nil, eventFilter.ContainsURL, eventFilter.Limit, FilterOrderAsc,
)
if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -421,21 +419,50 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
// selectEvents returns the events for the given event IDs. If an event is
// missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) SelectEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string,
ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
) ([]types.StreamEvent, error) {
var returnEvents []types.StreamEvent
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
for _, eventID := range eventIDs {
rows, err := stmt.QueryContext(ctx, eventID)
iEventIDs := make([]interface{}, len(eventIDs))
for i := range eventIDs {
iEventIDs[i] = eventIDs[i]
}
selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
if filter == nil {
filter = &gomatrixserverlib.RoomEventFilter{Limit: 20}
}
stmt, params, err := prepareWithFilters(
s.db, txn, selectSQL, iEventIDs,
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
)
if err != nil {
return nil, err
}
if streamEvents, err := rowsToStreamEvents(rows); err == nil {
returnEvents = append(returnEvents, streamEvents...)
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
streamEvents, err := rowsToStreamEvents(rows)
if err != nil {
return nil, err
}
if preserveOrder {
var returnEvents []types.StreamEvent
eventMap := make(map[string]types.StreamEvent)
for _, ev := range streamEvents {
eventMap[ev.EventID()] = ev
}
for _, eventID := range eventIDs {
ev, ok := eventMap[eventID]
if ok {
returnEvents = append(returnEvents, ev)
}
internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
}
return returnEvents, nil
}
return streamEvents, nil
}
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
@ -507,7 +534,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
},
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
nil, filter.Limit, FilterOrderDesc,
nil, filter.ContainsURL, filter.Limit, FilterOrderDesc,
)
rows, err := stmt.QueryContext(ctx, params...)
@ -543,7 +570,7 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent(
},
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
nil, filter.Limit, FilterOrderAsc,
nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
)
rows, err := stmt.QueryContext(ctx, params...)

View file

@ -78,7 +78,6 @@ type outputRoomEventsTopologyStatements struct {
selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt
deleteTopologyForRoomStmt *sql.Stmt
selectStreamToTopologicalPositionAscStmt *sql.Stmt
selectStreamToTopologicalPositionDescStmt *sql.Stmt
}
@ -191,10 +190,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return
}
func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -66,7 +66,7 @@ const selectMaxPeekIDSQL = "" +
type peekStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
insertPeekStmt *sql.Stmt
deletePeekStmt *sql.Stmt
deletePeeksStmt *sql.Stmt
@ -75,7 +75,7 @@ type peekStatements struct {
selectMaxPeekIDStmt *sql.Stmt
}
func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) {
func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) {
_, err := db.Exec(peeksSchema)
if err != nil {
return nil, err

View file

@ -75,7 +75,7 @@ const selectPresenceAfter = "" +
type presenceStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
upsertPresenceStmt *sql.Stmt
upsertPresenceFromSyncStmt *sql.Stmt
selectPresenceForUsersStmt *sql.Stmt
@ -83,7 +83,7 @@ type presenceStatements struct {
selectPresenceAfterStmt *sql.Stmt
}
func NewSqlitePresenceTable(db *sql.DB, streamID *streamIDStatements) (*presenceStatements, error) {
func NewSqlitePresenceTable(db *sql.DB, streamID *StreamIDStatements) (*presenceStatements, error) {
_, err := db.Exec(presenceSchema)
if err != nil {
return nil, err

View file

@ -59,13 +59,13 @@ const selectMaxReceiptIDSQL = "" +
type receiptStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
upsertReceipt *sql.Stmt
selectRoomReceipts *sql.Stmt
selectMaxReceiptID *sql.Stmt
}
func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) {
func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) {
_, err := db.Exec(receiptsSchema)
if err != nil {
return nil, err

View file

@ -32,12 +32,12 @@ const increaseStreamIDStmt = "" +
"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" +
" RETURNING stream_id"
type streamIDStatements struct {
type StreamIDStatements struct {
db *sql.DB
increaseStreamIDStmt *sql.Stmt
}
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) {
s.db = db
_, err = db.Exec(streamIDTableSchema)
if err != nil {
@ -49,31 +49,31 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
return
}
func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos)
return
}
func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *StreamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
return
}
func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *StreamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos)
return
}
func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *StreamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
return
}
func (s *streamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos)
return

View file

@ -30,7 +30,7 @@ type SyncServerDatasource struct {
shared.Database
db *sql.DB
writer sqlutil.Writer
streamID streamIDStatements
streamID StreamIDStatements
}
// NewDatabase creates a new sync server database
@ -49,7 +49,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
}
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
if err = d.streamID.prepare(d.db); err != nil {
if err = d.streamID.Prepare(d.db); err != nil {
return err
}
accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)

View file

@ -1,121 +1,29 @@
package storage_test
// TODO: Fix these tests
/*
import (
"context"
"crypto/ed25519"
"encoding/json"
"fmt"
"os"
"reflect"
"testing"
"time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
)
var (
ctx = context.Background()
emptyStateKey = ""
testOrigin = gomatrixserverlib.ServerName("hollow.knight")
testRoomID = fmt.Sprintf("!hallownest:%s", testOrigin)
testUserIDA = fmt.Sprintf("@hornet:%s", testOrigin)
testUserIDB = fmt.Sprintf("@paleking:%s", testOrigin)
testUserDeviceA = userapi.Device{
UserID: testUserIDA,
ID: "device_id_A",
DisplayName: "Device A",
}
testRoomVersion = gomatrixserverlib.RoomVersionV4
testKeyID = gomatrixserverlib.KeyID("ed25519:storage_test")
testPrivateKey = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
})
)
var ctx = context.Background()
func MustCreateEvent(t *testing.T, roomID string, prevs []*gomatrixserverlib.HeaderedEvent, b *gomatrixserverlib.EventBuilder) *gomatrixserverlib.HeaderedEvent {
b.RoomID = roomID
if prevs != nil {
prevIDs := make([]string, len(prevs))
for i := range prevs {
prevIDs[i] = prevs[i].EventID()
}
b.PrevEvents = prevIDs
}
e, err := b.Build(time.Now(), testOrigin, testKeyID, testPrivateKey, testRoomVersion)
if err != nil {
t.Fatalf("failed to build event: %s", err)
}
return e.Headered(testRoomVersion)
}
func MustCreateDatabase(t *testing.T) storage.Database {
dbname := fmt.Sprintf("test_%s.db", t.Name())
if _, err := os.Stat(dbname); err == nil {
if err = os.Remove(dbname); err != nil {
t.Fatalf("tried to delete stale test database but failed: %s", err)
}
}
db, err := sqlite3.NewDatabase(&config.DatabaseOptions{
ConnectionString: config.DataSource(fmt.Sprintf("file:%s", dbname)),
func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewSyncServerDatasource(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("NewSyncServerDatasource returned %s", err)
}
return db
}
// Create a list of events which include a create event, join event and some messages.
func SimpleRoom(t *testing.T, roomID, userA, userB string) (msgs []*gomatrixserverlib.HeaderedEvent, state []*gomatrixserverlib.HeaderedEvent) {
var events []*gomatrixserverlib.HeaderedEvent
events = append(events, MustCreateEvent(t, roomID, nil, &gomatrixserverlib.EventBuilder{
Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, userA)),
Type: "m.room.create",
StateKey: &emptyStateKey,
Sender: userA,
Depth: int64(len(events) + 1),
}))
state = append(state, events[len(events)-1])
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
Content: []byte(`{"membership":"join"}`),
Type: "m.room.member",
StateKey: &userA,
Sender: userA,
Depth: int64(len(events) + 1),
}))
state = append(state, events[len(events)-1])
for i := 0; i < 10; i++ {
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
Content: []byte(fmt.Sprintf(`{"body":"Message A %d"}`, i+1)),
Type: "m.room.message",
Sender: userA,
Depth: int64(len(events) + 1),
}))
}
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
Content: []byte(`{"membership":"join"}`),
Type: "m.room.member",
StateKey: &userB,
Sender: userB,
Depth: int64(len(events) + 1),
}))
state = append(state, events[len(events)-1])
for i := 0; i < 10; i++ {
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
Content: []byte(fmt.Sprintf(`{"body":"Message B %d"}`, i+1)),
Type: "m.room.message",
Sender: userB,
Depth: int64(len(events) + 1),
}))
}
return events, state
return db, close
}
func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) {
@ -131,206 +39,158 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
if err != nil {
t.Fatalf("WriteEvent failed: %s", err)
}
fmt.Println("Event ID", ev.EventID(), " spos=", pos, "depth=", ev.Depth())
t.Logf("Event ID %s spos=%v depth=%v", ev.EventID(), pos, ev.Depth())
positions = append(positions, pos)
}
return
}
func TestWriteEvents(t *testing.T) {
t.Parallel()
db := MustCreateDatabase(t)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
MustWriteEvents(t, db, events)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
alice := test.NewUser()
r := test.NewRoom(t, alice)
db, close := MustCreateDatabase(t, dbType)
defer close()
MustWriteEvents(t, db, r.Events())
})
}
// These tests assert basic functionality of the IncrementalSync and CompleteSync functions.
func TestSyncResponse(t *testing.T) {
t.Parallel()
db := MustCreateDatabase(t)
events, state := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
// These tests assert basic functionality of RecentEvents for PDUs
func TestRecentEventsPDU(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser()
// dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
// actual test room
r := test.NewRoom(t, alice)
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"})
events := r.Events()
positions := MustWriteEvents(t, db, events)
latest, err := db.SyncPosition(ctx)
// dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
latest, err := db.MaxStreamPositionForPDUs(ctx)
if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err)
t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
}
testCases := []struct {
Name string
DoSync func() (*types.Response, error)
WantTimeline []*gomatrixserverlib.HeaderedEvent
WantState []*gomatrixserverlib.HeaderedEvent
From types.StreamPosition
To types.StreamPosition
Limit int
ReverseOrder bool
WantEvents []*gomatrixserverlib.HeaderedEvent
WantLimited bool
}{
// The purpose of this test is to make sure that incremental syncs are including up to the latest events.
// It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event.
// It's a basic sanity test that sync works. It creates a streaming position that is on the penultimate event.
// It makes sure the response includes the final event.
{
Name: "IncrementalSync penultimate",
DoSync: func() (*types.Response, error) {
from := types.StreamingToken{ // pretend we are at the penultimate event
PDUPosition: positions[len(positions)-2],
}
res := types.NewResponse()
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
Name: "penultimate",
From: positions[len(positions)-2], // pretend we are at the penultimate event
To: latest,
Limit: 100,
WantEvents: events[len(events)-1:],
WantLimited: false,
},
WantTimeline: events[len(events)-1:],
},
// The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the
// number of returned events. This is critical for big rooms hence the test here.
// The purpose of this test is to check that limits can be applied and work.
// This is critical for big rooms hence the test here.
{
Name: "IncrementalSync limited",
DoSync: func() (*types.Response, error) {
from := types.StreamingToken{ // pretend we are 10 events behind
PDUPosition: positions[len(positions)-11],
}
res := types.NewResponse()
// limit is set to 5
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
Name: "limited",
From: 0,
To: latest,
Limit: 1,
WantEvents: events[len(events)-1:],
WantLimited: true,
},
// want the last 5 events, NOT the last 10.
WantTimeline: events[len(events)-5:],
},
// The purpose of this test is to check that CompleteSync returns all the current state as well as
// honouring the `numRecentEventsPerRoom` value
// The purpose of this test is to check that we can return every event with a high
// enough limit
{
Name: "CompleteSync limited",
DoSync: func() (*types.Response, error) {
res := types.NewResponse()
// limit set to 5
return db.CompleteSync(ctx, res, testUserDeviceA, 5)
Name: "large limited",
From: 0,
To: latest,
Limit: 100,
WantEvents: events,
WantLimited: false,
},
// want the last 5 events
WantTimeline: events[len(events)-5:],
// want all state for the room
WantState: state,
},
// The purpose of this test is to check that CompleteSync can return everything with a high enough
// `numRecentEventsPerRoom`.
// The purpose of this test is to check that we can return events in reverse order
{
Name: "CompleteSync",
DoSync: func() (*types.Response, error) {
res := types.NewResponse()
return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
},
WantTimeline: events,
// We want no state at all as that field in /sync is the delta between the token (beginning of time)
// and the START of the timeline.
Name: "reverse",
From: positions[len(positions)-3], // 2 events back
To: latest,
Limit: 100,
ReverseOrder: true,
WantEvents: test.Reversed(events[len(events)-2:]),
WantLimited: false,
},
}
for _, tc := range testCases {
for i := range testCases {
tc := testCases[i]
t.Run(tc.Name, func(st *testing.T) {
res, err := tc.DoSync()
var filter gomatrixserverlib.RoomEventFilter
filter.Limit = tc.Limit
gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
From: tc.From,
To: tc.To,
}, &filter, !tc.ReverseOrder, true)
if err != nil {
st.Fatalf("failed to do sync: %s", err)
}
next := types.StreamingToken{
PDUPosition: latest.PDUPosition,
TypingPosition: latest.TypingPosition,
ReceiptPosition: latest.ReceiptPosition,
SendToDevicePosition: latest.SendToDevicePosition,
if limited != tc.WantLimited {
st.Errorf("got limited=%v want %v", limited, tc.WantLimited)
}
if res.NextBatch.String() != next.String() {
st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String())
if len(gotEvents) != len(tc.WantEvents) {
st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents))
}
for j := range gotEvents {
if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) {
st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON()))
}
roomRes, ok := res.Rooms.Join[testRoomID]
if !ok {
st.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res)
}
assertEventsEqual(st, "state for "+testRoomID, false, roomRes.State.Events, tc.WantState)
assertEventsEqual(st, "timeline for "+testRoomID, false, roomRes.Timeline.Events, tc.WantTimeline)
})
}
}
func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
t.Parallel()
db := MustCreateDatabase(t)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
positions := MustWriteEvents(t, db, events)
latest, err := db.SyncPosition(ctx)
if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err)
}
from := types.StreamingToken{
PDUPosition: positions[len(positions)-2],
}
res := types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
if err != nil {
t.Fatalf("failed to IncrementalSync with latest token")
}
roomRes, ok := res.Rooms.Join[testRoomID]
if !ok {
t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res)
}
// returns the last event "Message 10"
assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:]))
prev := roomRes.Timeline.PrevBatch.String()
if prev == "" {
t.Fatalf("IncrementalSync expected prev_batch token")
}
prevBatchToken, err := types.NewTopologyTokenFromString(prev)
if err != nil {
t.Fatalf("failed to NewTopologyTokenFromString : %s", err)
}
// backpaginate 5 messages starting at the latest position.
// head towards the beginning of time
to := types.TopologyToken{}
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true)
if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err)
}
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1]))
}
// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token.
func TestGetEventsInRangeWithStreamToken(t *testing.T) {
t.Parallel()
db := MustCreateDatabase(t)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
MustWriteEvents(t, db, events)
latest, err := db.SyncPosition(ctx)
if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err)
}
// head towards the beginning of time
to := types.StreamingToken{}
// backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err)
}
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
})
}
// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token
func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
t.Parallel()
db := MustCreateDatabase(t)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
MustWriteEvents(t, db, events)
from, err := db.MaxTopologicalPosition(ctx, testRoomID)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser()
r := test.NewRoom(t, alice)
for i := 0; i < 10; i++ {
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})
}
events := r.Events()
_ = MustWriteEvents(t, db, events)
from, err := db.MaxTopologicalPosition(ctx, r.ID)
if err != nil {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
}
t.Logf("max topo pos = %+v", from)
// head towards the beginning of time
to := types.TopologyToken{}
// backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true)
filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err)
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
}
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
gots := db.StreamEventsToEvents(nil, paginatedEvents)
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
})
}
/*
// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth.
// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent
// will appear FIRST when going backwards. This test creates a DAG like:
@ -740,12 +600,4 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ
tok.Decrement()
return &tok
}
func reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[len(in)-i-1]
}
return out
}
*/

View file

@ -59,7 +59,7 @@ type Events interface {
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
// SelectEarlyEvents returns the earliest events in the given room.
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error)
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
@ -84,8 +84,6 @@ type Topology interface {
SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error)
// SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position.
SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error)
// DeleteTopologyForRoom removes all topological information for a room. This should only be done when removing the room entirely.
DeleteTopologyForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
// SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room.
SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error)
}
@ -132,8 +130,6 @@ type BackwardsExtremities interface {
SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error)
// DeleteBackwardExtremity removes a backwards extremity for a room, if one existed.
DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error)
// DeleteBackwardExtremitiesFoorRoomID removes all backward extremities for a room. This should only be done when removing the room entirely.
DeleteBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
}
// SendToDevice tracks send-to-device messages which are sent to individual
@ -173,7 +169,6 @@ type Receipts interface {
type Memberships interface {
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error)
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
}

View file

@ -0,0 +1,105 @@
package tables_test
import (
"context"
"database/sql"
"fmt"
"reflect"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
)
func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.Events
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresEventsTable(db)
case test.DBTypeSQLite:
var stream sqlite3.StreamIDStatements
if err = stream.Prepare(db); err != nil {
t.Fatalf("failed to prepare stream stmts: %s", err)
}
tab, err = sqlite3.NewSqliteEventsTable(db, &stream)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, db, close
}
func TestOutputRoomEventsTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newOutputRoomEventsTable(t, dbType)
defer close()
events := room.Events()
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
for _, ev := range events {
_, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false)
if err != nil {
return fmt.Errorf("failed to InsertEvent: %s", err)
}
}
// order = 2,0,3,1
wantEventIDs := []string{
events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(),
}
gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, nil, true)
if err != nil {
return fmt.Errorf("failed to SelectEvents: %s", err)
}
gotEventIDs := make([]string, len(gotEvents))
for i := range gotEvents {
gotEventIDs[i] = gotEvents[i].EventID()
}
if !reflect.DeepEqual(gotEventIDs, wantEventIDs) {
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs)
}
// Test that contains_url is correctly populated
urlEv := room.CreateEvent(t, alice, "m.text", map[string]interface{}{
"body": "test.txt",
"url": "mxc://test.txt",
})
if _, err = tab.InsertEvent(ctx, txn, urlEv, nil, nil, nil, false); err != nil {
return fmt.Errorf("failed to InsertEvent: %s", err)
}
wantEventID := []string{urlEv.EventID()}
t := true
gotEvents, err = tab.SelectEvents(ctx, txn, wantEventID, &gomatrixserverlib.RoomEventFilter{Limit: 1, ContainsURL: &t}, true)
if err != nil {
return fmt.Errorf("failed to SelectEvents: %s", err)
}
gotEventIDs = make([]string, len(gotEvents))
for i := range gotEvents {
gotEventIDs[i] = gotEvents[i].EventID()
}
if !reflect.DeepEqual(gotEventIDs, wantEventID) {
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventID)
}
return nil
})
if err != nil {
t.Fatalf("err: %s", err)
}
})
}

View file

@ -0,0 +1,91 @@
package tables_test
import (
"context"
"database/sql"
"fmt"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
)
func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.Topology
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresTopologyTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSqliteTopologyTable(db)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, db, close
}
func TestTopologyTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newTopologyTable(t, dbType)
defer close()
events := room.Events()
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
var highestPos types.StreamPosition
for i, ev := range events {
topoPos, err := tab.InsertEventInTopology(ctx, txn, ev, types.StreamPosition(i))
if err != nil {
return fmt.Errorf("failed to InsertEventInTopology: %s", err)
}
// topo pos = depth, depth starts at 1, hence 1+i
if topoPos != types.StreamPosition(1+i) {
return fmt.Errorf("got topo pos %d want %d", topoPos, 1+i)
}
highestPos = topoPos + 1
}
// check ordering works without limit
eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, events[:])
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:]))
// check ordering works with limit
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, events[:3])
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:]))
return nil
})
if err != nil {
t.Fatalf("err: %s", err)
}
})
}

View file

@ -6,6 +6,7 @@ import (
"sync"
"time"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
@ -26,7 +27,8 @@ type PDUStreamProvider struct {
tasks chan func()
workers atomic.Int32
userAPI userapi.UserInternalAPI
// userID+deviceID -> lazy loading cache
lazyLoadCache *caching.LazyLoadCache
}
func (p *PDUStreamProvider) worker() {
@ -188,7 +190,7 @@ func (p *PDUStreamProvider) IncrementalSync(
newPos = from
for _, delta := range stateDeltas {
var pos types.StreamPosition
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, req.Response); err != nil {
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, &stateFilter, req.Response); err != nil {
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
return to
}
@ -209,6 +211,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
r types.Range,
delta types.StateDelta,
eventFilter *gomatrixserverlib.RoomEventFilter,
stateFilter *gomatrixserverlib.StateFilter,
res *types.Response,
) (types.StreamPosition, error) {
if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave {
@ -247,7 +250,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// room that were returned.
latestPosition := r.To
updateLatestPosition := func(mostRecentEventID string) {
if _, pos, err := p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil {
var pos types.StreamPosition
if _, pos, err = p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil {
switch {
case r.Backwards && pos > latestPosition:
fallthrough
@ -263,6 +267,19 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
}
if stateFilter.LazyLoadMembers {
if err != nil {
return r.From, err
}
delta.StateEvents, err = p.lazyLoadMembers(
ctx, delta.RoomID, true, limited, stateFilter.IncludeRedundantMembers,
device, recentEvents, delta.StateEvents,
)
if err != nil {
return r.From, err
}
}
hasMembershipChange := false
for _, recentEvent := range recentStreamEvents {
if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil {
@ -402,6 +419,20 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
// "Can sync a room with a message with a transaction id" - which does a complete sync to check.
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents)
if stateFilter.LazyLoadMembers {
if err != nil {
return nil, err
}
stateEvents, err = p.lazyLoadMembers(ctx, roomID,
false, limited, stateFilter.IncludeRedundantMembers,
device, recentEvents, stateEvents,
)
if err != nil {
return nil, err
}
}
jr = types.NewJoinResponse()
jr.Summary.JoinedMemberCount = &joinedCount
jr.Summary.InvitedMemberCount = &invitedCount
@ -412,6 +443,69 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
return jr, nil
}
func (p *PDUStreamProvider) lazyLoadMembers(
ctx context.Context, roomID string,
incremental, limited, includeRedundant bool,
device *userapi.Device,
timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
if len(timelineEvents) == 0 {
return stateEvents, nil
}
// Work out which memberships to include
timelineUsers := make(map[string]struct{})
if !incremental {
timelineUsers[device.UserID] = struct{}{}
}
// Add all users the client doesn't know about yet to a list
for _, event := range timelineEvents {
// Membership is not yet cached, add it to the list
if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.Sender()); !ok {
timelineUsers[event.Sender()] = struct{}{}
}
}
// Preallocate with the same amount, even if it will end up with fewer values
newStateEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(stateEvents))
// Remove existing membership events we don't care about, e.g. users not in the timeline.events
for _, event := range stateEvents {
if event.Type() == gomatrixserverlib.MRoomMember && event.StateKey() != nil {
// If this is a gapped incremental sync, we still want this membership
isGappedIncremental := limited && incremental
// We want this users membership event, keep it in the list
_, ok := timelineUsers[event.Sender()]
wantMembership := ok || isGappedIncremental
if wantMembership {
newStateEvents = append(newStateEvents, event)
if !includeRedundant {
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, event.Sender(), event.EventID())
}
delete(timelineUsers, event.Sender())
}
} else {
newStateEvents = append(newStateEvents, event)
}
}
wantUsers := make([]string, 0, len(timelineUsers))
for userID := range timelineUsers {
wantUsers = append(wantUsers, userID)
}
// Query missing membership events
memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &gomatrixserverlib.StateFilter{
Limit: 100,
Senders: &wantUsers,
Types: &[]string{gomatrixserverlib.MRoomMember},
})
if err != nil {
return stateEvents, err
}
// cache the membership events
for _, membership := range memberships {
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, membership.Sender(), membership.EventID())
}
stateEvents = append(newStateEvents, memberships...)
return stateEvents, nil
}
// addIgnoredUsersToFilter adds ignored users to the eventfilter and
// the syncreq itself for further use in streams.
func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error {
@ -423,8 +517,12 @@ func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *ty
return err
}
req.IgnoredUsers = *ignores
userList := make([]string, 0, len(ignores.List))
for userID := range ignores.List {
eventFilter.NotSenders = append(eventFilter.NotSenders, userID)
userList = append(userList, userID)
}
if len(userList) > 0 {
eventFilter.NotSenders = &userList
}
return nil
}

View file

@ -27,12 +27,12 @@ type Streams struct {
func NewSyncStreamProviders(
d storage.Database, userAPI userapi.UserInternalAPI,
rsAPI rsapi.RoomserverInternalAPI, keyAPI keyapi.KeyInternalAPI,
eduCache *caching.EDUCache, notifier *notifier.Notifier,
eduCache *caching.EDUCache, lazyLoadCache *caching.LazyLoadCache, notifier *notifier.Notifier,
) *Streams {
streams := &Streams{
PDUStreamProvider: &PDUStreamProvider{
StreamProvider: StreamProvider{DB: d},
userAPI: userAPI,
lazyLoadCache: lazyLoadCache,
},
TypingStreamProvider: &TypingStreamProvider{
StreamProvider: StreamProvider{DB: d},

View file

@ -15,6 +15,7 @@
package sync
import (
"database/sql"
"encoding/json"
"fmt"
"net/http"
@ -60,10 +61,10 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil {
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil && err != sql.ErrNoRows {
util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed")
return nil, fmt.Errorf("syncDB.GetFilter: %w", err)
} else {
} else if f != nil {
filter = *f
}
}

View file

@ -67,6 +67,9 @@ func NewRequestPool(
streams *streams.Streams, notifier *notifier.Notifier,
producer PresencePublisher,
) *RequestPool {
prometheus.MustRegister(
activeSyncRequests, waitingSyncRequests,
)
rp := &RequestPool{
db: db,
cfg: cfg,
@ -183,12 +186,6 @@ func (rp *RequestPool) updateLastSeen(req *http.Request, device *userapi.Device)
rp.lastseen.Store(device.UserID+device.ID, time.Now())
}
func init() {
prometheus.MustRegister(
activeSyncRequests, waitingSyncRequests,
)
}
var activeSyncRequests = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",

View file

@ -57,8 +57,12 @@ func AddPublicRoutes(
}
eduCache := caching.NewTypingCache()
lazyLoadCache, err := caching.NewLazyLoadCache()
if err != nil {
logrus.WithError(err).Panicf("failed to create lazy loading cache")
}
notifier := notifier.NewNotifier()
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, notifier)
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, lazyLoadCache, notifier)
notifier.SetCurrentPosition(streams.Latest(context.Background()))
if err = notifier.Load(context.Background(), syncDB); err != nil {
logrus.WithError(err).Panicf("failed to load notifier ")

View file

@ -697,3 +697,16 @@ Room state after a rejected state event is the same as before
Ignore user in existing room
Ignore invite in full sync
Ignore invite in incremental sync
A filtered timeline reaches its limit
A change to displayname should not result in a full state sync
Can fetch images in room
The only membership state included in an initial sync is for all the senders in the timeline
The only membership state included in an incremental sync is for senders in the timeline
Old members are included in gappy incr LL sync if they start speaking
We do send redundant membership state across incremental syncs if asked
Rejecting invite over federation doesn't break incremental /sync
Gapped incremental syncs include all state changes
Old leaves are present in gapped incremental syncs
Leaves are present in non-gapped incremental syncs
Members from the gap are included in gappy incr LL sync
Presence can be set from sync

170
test/db.go Normal file
View file

@ -0,0 +1,170 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"crypto/sha256"
"database/sql"
"encoding/hex"
"fmt"
"os"
"os/exec"
"os/user"
"testing"
"github.com/lib/pq"
)
type DBType int
var DBTypeSQLite DBType = 1
var DBTypePostgres DBType = 2
var Quiet = false
func createLocalDB(dbName string) {
if !Quiet {
fmt.Println("Note: tests require a postgres install accessible to the current user")
}
createDB := exec.Command("createdb", dbName)
if !Quiet {
createDB.Stdout = os.Stdout
createDB.Stderr = os.Stderr
}
err := createDB.Run()
if err != nil && !Quiet {
fmt.Println("createLocalDB returned error:", err)
}
}
func createRemoteDB(t *testing.T, dbName, user, connStr string) {
db, err := sql.Open("postgres", connStr+" dbname=postgres")
if err != nil {
t.Fatalf("failed to open postgres conn with connstr=%s : %s", connStr, err)
}
_, err = db.Exec(fmt.Sprintf(`CREATE DATABASE %s;`, dbName))
if err != nil {
pqErr, ok := err.(*pq.Error)
if !ok {
t.Fatalf("failed to CREATE DATABASE: %s", err)
}
// we ignore duplicate database error as we expect this
if pqErr.Code != "42P04" {
t.Fatalf("failed to CREATE DATABASE with code=%s msg=%s", pqErr.Code, pqErr.Message)
}
}
_, err = db.Exec(fmt.Sprintf(`GRANT ALL PRIVILEGES ON DATABASE %s TO %s`, dbName, user))
if err != nil {
t.Fatalf("failed to GRANT: %s", err)
}
_ = db.Close()
}
func currentUser() string {
user, err := user.Current()
if err != nil {
if !Quiet {
fmt.Println("cannot get current user: ", err)
}
os.Exit(2)
}
return user.Username
}
// Prepare a sqlite or postgres connection string for testing.
// Returns the connection string to use and a close function which must be called when the test finishes.
// Calling this function twice will return the same database, which will have data from previous tests
// unless close() is called.
// TODO: namespace for concurrent package tests
func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) {
if dbType == DBTypeSQLite {
// this will be made in the current working directory which namespaces concurrent package runs correctly
dbname := "dendrite_test.db"
return fmt.Sprintf("file:%s", dbname), func() {
err := os.Remove(dbname)
if err != nil {
t.Fatalf("failed to cleanup sqlite db '%s': %s", dbname, err)
}
}
}
// Required vars: user and db
// We'll try to infer from the local env if they are missing
user := os.Getenv("POSTGRES_USER")
if user == "" {
user = currentUser()
}
connStr = fmt.Sprintf(
"user=%s sslmode=disable",
user,
)
// optional vars, used in CI
password := os.Getenv("POSTGRES_PASSWORD")
if password != "" {
connStr += fmt.Sprintf(" password=%s", password)
}
host := os.Getenv("POSTGRES_HOST")
if host != "" {
connStr += fmt.Sprintf(" host=%s", host)
}
// superuser database
postgresDB := os.Getenv("POSTGRES_DB")
// we cannot use 'dendrite_test' here else 2x concurrently running packages will try to use the same db.
// instead, hash the current working directory, snaffle the first 16 bytes and append that to dendrite_test
// and use that as the unique db name. We do this because packages are per-directory hence by hashing the
// working (test) directory we ensure we get a consistent hash and don't hash against concurrent packages.
wd, err := os.Getwd()
if err != nil {
t.Fatalf("cannot get working directory: %s", err)
}
hash := sha256.Sum256([]byte(wd))
dbName := fmt.Sprintf("dendrite_test_%s", hex.EncodeToString(hash[:16]))
if postgresDB == "" { // local server, use createdb
createLocalDB(dbName)
} else { // remote server, shell into the postgres user and CREATE DATABASE
createRemoteDB(t, dbName, user, connStr)
}
connStr += fmt.Sprintf(" dbname=%s", dbName)
return connStr, func() {
// Drop all tables on the database to get a fresh instance
db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatalf("failed to connect to postgres db '%s': %s", connStr, err)
}
_, err = db.Exec(`DROP SCHEMA public CASCADE;
CREATE SCHEMA public;`)
if err != nil {
t.Fatalf("failed to cleanup postgres db '%s': %s", connStr, err)
}
_ = db.Close()
}
}
// Creates subtests with each known DBType
func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) {
dbs := map[string]DBType{
"postgres": DBTypePostgres,
"sqlite": DBTypeSQLite,
}
for dbName, dbType := range dbs {
dbt := dbType
t.Run(dbName, func(tt *testing.T) {
tt.Parallel()
testFn(tt, dbt)
})
}
}

90
test/event.go Normal file
View file

@ -0,0 +1,90 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"bytes"
"crypto/ed25519"
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
)
type eventMods struct {
originServerTS time.Time
origin gomatrixserverlib.ServerName
stateKey *string
unsigned interface{}
keyID gomatrixserverlib.KeyID
privKey ed25519.PrivateKey
}
type eventModifier func(e *eventMods)
func WithTimestamp(ts time.Time) eventModifier {
return func(e *eventMods) {
e.originServerTS = ts
}
}
func WithStateKey(skey string) eventModifier {
return func(e *eventMods) {
e.stateKey = &skey
}
}
func WithUnsigned(unsigned interface{}) eventModifier {
return func(e *eventMods) {
e.unsigned = unsigned
}
}
// Reverse a list of events
func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[len(in)-i-1]
}
return out
}
func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*gomatrixserverlib.HeaderedEvent) {
t.Helper()
if len(gotEventIDs) != len(wants) {
t.Fatalf("length mismatch: got %d events, want %d", len(gotEventIDs), len(wants))
}
for i := range wants {
w := wants[i].EventID()
g := gotEventIDs[i]
if w != g {
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
}
}
}
func AssertEventsEqual(t *testing.T, gots, wants []*gomatrixserverlib.HeaderedEvent) {
t.Helper()
if len(gots) != len(wants) {
t.Fatalf("length mismatch: got %d events, want %d", len(gots), len(wants))
}
for i := range wants {
w := wants[i].JSON()
g := gots[i].JSON()
if !bytes.Equal(w, g) {
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
}
}
}

223
test/room.go Normal file
View file

@ -0,0 +1,223 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"crypto/ed25519"
"encoding/json"
"fmt"
"sync/atomic"
"testing"
"time"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/gomatrixserverlib"
)
type Preset int
var (
PresetNone Preset = 0
PresetPrivateChat Preset = 1
PresetPublicChat Preset = 2
PresetTrustedPrivateChat Preset = 3
roomIDCounter = int64(0)
testKeyID = gomatrixserverlib.KeyID("ed25519:test")
testPrivateKey = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
})
)
type Room struct {
ID string
Version gomatrixserverlib.RoomVersion
preset Preset
creator *User
authEvents gomatrixserverlib.AuthEvents
events []*gomatrixserverlib.HeaderedEvent
}
// Create a new test room. Automatically creates the initial create events.
func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
t.Helper()
counter := atomic.AddInt64(&roomIDCounter, 1)
// set defaults then let roomModifiers override
r := &Room{
ID: fmt.Sprintf("!%d:localhost", counter),
creator: creator,
authEvents: gomatrixserverlib.NewAuthEvents(nil),
preset: PresetPublicChat,
Version: gomatrixserverlib.RoomVersionV9,
}
for _, m := range modifiers {
m(t, r)
}
r.insertCreateEvents(t)
return r
}
func (r *Room) insertCreateEvents(t *testing.T) {
t.Helper()
var joinRule gomatrixserverlib.JoinRuleContent
var hisVis gomatrixserverlib.HistoryVisibilityContent
plContent := eventutil.InitialPowerLevelsContent(r.creator.ID)
switch r.preset {
case PresetTrustedPrivateChat:
fallthrough
case PresetPrivateChat:
joinRule.JoinRule = "invite"
hisVis.HistoryVisibility = "shared"
case PresetPublicChat:
joinRule.JoinRule = "public"
hisVis.HistoryVisibility = "shared"
}
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
"creator": r.creator.ID,
"room_version": r.Version,
}, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, WithStateKey(r.creator.ID))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
}
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
t.Helper()
depth := 1 + len(r.events) // depth starts at 1
// possible event modifiers (optional fields)
mod := &eventMods{}
for _, m := range mods {
m(mod)
}
if mod.privKey == nil {
mod.privKey = testPrivateKey
}
if mod.keyID == "" {
mod.keyID = testKeyID
}
if mod.originServerTS.IsZero() {
mod.originServerTS = time.Now()
}
if mod.origin == "" {
mod.origin = gomatrixserverlib.ServerName("localhost")
}
var unsigned gomatrixserverlib.RawJSON
var err error
if mod.unsigned != nil {
unsigned, err = json.Marshal(mod.unsigned)
if err != nil {
t.Fatalf("CreateEvent[%s]: failed to marshal unsigned field: %s", eventType, err)
}
}
builder := &gomatrixserverlib.EventBuilder{
Sender: creator.ID,
RoomID: r.ID,
Type: eventType,
StateKey: mod.stateKey,
Depth: int64(depth),
Unsigned: unsigned,
}
err = builder.SetContent(content)
if err != nil {
t.Fatalf("CreateEvent[%s]: failed to SetContent: %s", eventType, err)
}
if depth > 1 {
builder.PrevEvents = []gomatrixserverlib.EventReference{r.events[len(r.events)-1].EventReference()}
}
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
if err != nil {
t.Fatalf("CreateEvent[%s]: failed to StateNeededForEventBuilder: %s", eventType, err)
}
refs, err := eventsNeeded.AuthEventReferences(&r.authEvents)
if err != nil {
t.Fatalf("CreateEvent[%s]: failed to AuthEventReferences: %s", eventType, err)
}
builder.AuthEvents = refs
ev, err := builder.Build(
mod.originServerTS, mod.origin, mod.keyID,
mod.privKey, r.Version,
)
if err != nil {
t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err)
}
if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil {
t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err)
}
return ev.Headered(r.Version)
}
// Add a new event to this room DAG. Not thread-safe.
func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) {
t.Helper()
// Add the event to the list of auth events
r.events = append(r.events, he)
if he.StateKey() != nil {
err := r.authEvents.AddEvent(he.Unwrap())
if err != nil {
t.Fatalf("InsertEvent: failed to add event to auth events: %s", err)
}
}
}
func (r *Room) Events() []*gomatrixserverlib.HeaderedEvent {
return r.events
}
func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
t.Helper()
he := r.CreateEvent(t, creator, eventType, content, mods...)
r.InsertEvent(t, he)
return he
}
// All room modifiers are below
type roomModifier func(t *testing.T, r *Room)
func RoomPreset(p Preset) roomModifier {
return func(t *testing.T, r *Room) {
switch p {
case PresetPrivateChat:
fallthrough
case PresetPublicChat:
fallthrough
case PresetTrustedPrivateChat:
fallthrough
case PresetNone:
r.preset = p
default:
t.Errorf("invalid RoomPreset: %v", p)
}
}
}
func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
return func(t *testing.T, r *Room) {
r.Version = ver
}
}

View file

@ -1,5 +1,4 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,24 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
package test
import (
"database/sql"
"fmt"
"sync/atomic"
)
type statements struct {
media mediaStatements
thumbnail thumbnailStatements
var (
userIDCounter = int64(0)
)
type User struct {
ID string
}
func (s *statements) prepare(db *sql.DB) (err error) {
if err = s.media.prepare(db); err != nil {
return
func NewUser() *User {
counter := atomic.AddInt64(&userIDCounter, 1)
u := &User{
ID: fmt.Sprintf("@%d:localhost", counter),
}
if err = s.thumbnail.prepare(db); err != nil {
return
}
return
return u
}

View file

@ -494,7 +494,7 @@ type PerformPusherDeletionRequest struct {
type Pusher struct {
SessionID int64 `json:"session_id,omitempty"`
PushKey string `json:"pushkey"`
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
PushKeyTS int64 `json:"pushkey_ts,omitempty"`
Kind PusherKind `json:"kind"`
AppID string `json:"app_id"`
AppDisplayName string `json:"app_display_name"`

View file

@ -653,7 +653,7 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
}
if req.Pusher.PushKeyTS == 0 {
req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now())
req.Pusher.PushKeyTS = int64(time.Now().Unix())
}
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
}

View file

@ -369,6 +369,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryAccountAvailabilityPath,
httputil.MakeInternalAPI("queryAccountAvailability", func(req *http.Request) util.JSONResponse {
request := api.QueryAccountAvailabilityRequest{}
response := api.QueryAccountAvailabilityResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryAccountAvailability(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryAccountByPasswordPath,
httputil.MakeInternalAPI("queryAccountByPassword", func(req *http.Request) util.JSONResponse {
request := api.QueryAccountByPasswordRequest{}

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
@ -95,7 +94,7 @@ type pushersStatements struct {
// Returns nil error success.
func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
) error {
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
@ -95,7 +94,7 @@ type pushersStatements struct {
// Returns nil error success.
func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
) error {
_, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)

View file

@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
type AccountDataTable interface {
@ -96,7 +95,7 @@ type ThreePIDTable interface {
}
type PusherTable interface {
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error