mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-01 03:03:10 -06:00
Merge branch 'matrix-org:main' into main
This commit is contained in:
commit
d0666902f9
25
.github/workflows/dendrite.yml
vendored
25
.github/workflows/dendrite.yml
vendored
|
|
@ -73,6 +73,26 @@ jobs:
|
|||
timeout-minutes: 5
|
||||
name: Unit tests (Go ${{ matrix.go }})
|
||||
runs-on: ubuntu-latest
|
||||
# Service containers to run with `container-job`
|
||||
services:
|
||||
# Label used to access the service container
|
||||
postgres:
|
||||
# Docker Hub image
|
||||
image: postgres:13-alpine
|
||||
# Provide the password for postgres
|
||||
env:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: dendrite
|
||||
ports:
|
||||
# Maps tcp port 5432 on service container to the host
|
||||
- 5432:5432
|
||||
# Set health checks to wait until postgres has started
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
|
|
@ -92,6 +112,11 @@ jobs:
|
|||
restore-keys: |
|
||||
${{ runner.os }}-go${{ matrix.go }}-test-
|
||||
- run: go test ./...
|
||||
env:
|
||||
POSTGRES_HOST: localhost
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: dendrite
|
||||
|
||||
# build Dendrite for linux with different architectures and go versions
|
||||
build:
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ import (
|
|||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
|
||||
pineconeConnections "github.com/matrix-org/pinecone/connections"
|
||||
pineconeMulticast "github.com/matrix-org/pinecone/multicast"
|
||||
pineconeRouter "github.com/matrix-org/pinecone/router"
|
||||
pineconeSessions "github.com/matrix-org/pinecone/sessions"
|
||||
|
|
@ -71,11 +72,9 @@ type DendriteMonolith struct {
|
|||
PineconeRouter *pineconeRouter.Router
|
||||
PineconeMulticast *pineconeMulticast.Multicast
|
||||
PineconeQUIC *pineconeSessions.Sessions
|
||||
PineconeManager *pineconeConnections.ConnectionManager
|
||||
StorageDirectory string
|
||||
CacheDirectory string
|
||||
staticPeerURI string
|
||||
staticPeerMutex sync.RWMutex
|
||||
staticPeerAttempt chan struct{}
|
||||
listener net.Listener
|
||||
httpServer *http.Server
|
||||
processContext *process.ProcessContext
|
||||
|
|
@ -104,15 +103,8 @@ func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) {
|
|||
}
|
||||
|
||||
func (m *DendriteMonolith) SetStaticPeer(uri string) {
|
||||
m.staticPeerMutex.Lock()
|
||||
m.staticPeerURI = strings.TrimSpace(uri)
|
||||
m.staticPeerMutex.Unlock()
|
||||
m.DisconnectType(int(pineconeRouter.PeerTypeRemote))
|
||||
if uri != "" {
|
||||
go func() {
|
||||
m.staticPeerAttempt <- struct{}{}
|
||||
}()
|
||||
}
|
||||
m.PineconeManager.RemovePeers()
|
||||
m.PineconeManager.AddPeer(strings.TrimSpace(uri))
|
||||
}
|
||||
|
||||
func (m *DendriteMonolith) DisconnectType(peertype int) {
|
||||
|
|
@ -210,43 +202,6 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e
|
|||
return loginRes.Device.AccessToken, nil
|
||||
}
|
||||
|
||||
func (m *DendriteMonolith) staticPeerConnect() {
|
||||
connected := map[string]bool{} // URI -> connected?
|
||||
attempt := func() {
|
||||
m.staticPeerMutex.RLock()
|
||||
uri := m.staticPeerURI
|
||||
m.staticPeerMutex.RUnlock()
|
||||
if uri == "" {
|
||||
return
|
||||
}
|
||||
for k := range connected {
|
||||
delete(connected, k)
|
||||
}
|
||||
for _, uri := range strings.Split(uri, ",") {
|
||||
connected[strings.TrimSpace(uri)] = false
|
||||
}
|
||||
for _, info := range m.PineconeRouter.Peers() {
|
||||
connected[info.URI] = true
|
||||
}
|
||||
for k, online := range connected {
|
||||
if !online {
|
||||
if err := conn.ConnectToPeer(m.PineconeRouter, k); err != nil {
|
||||
logrus.WithError(err).Error("Failed to connect to static peer")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-m.processContext.Context().Done():
|
||||
case <-m.staticPeerAttempt:
|
||||
attempt()
|
||||
case <-time.After(time.Second * 5):
|
||||
attempt()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// nolint:gocyclo
|
||||
func (m *DendriteMonolith) Start() {
|
||||
var err error
|
||||
|
|
@ -284,6 +239,7 @@ func (m *DendriteMonolith) Start() {
|
|||
m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
|
||||
m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"})
|
||||
m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter)
|
||||
m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter)
|
||||
|
||||
prefix := hex.EncodeToString(pk)
|
||||
cfg := &config.Dendrite{}
|
||||
|
|
@ -392,9 +348,6 @@ func (m *DendriteMonolith) Start() {
|
|||
|
||||
m.processContext = base.ProcessContext
|
||||
|
||||
m.staticPeerAttempt = make(chan struct{}, 1)
|
||||
go m.staticPeerConnect()
|
||||
|
||||
go func() {
|
||||
m.logger.Info("Listening on ", cfg.Global.ServerName)
|
||||
m.logger.Fatal(m.httpServer.Serve(m.PineconeQUIC.Protocol("matrix")))
|
||||
|
|
|
|||
|
|
@ -95,10 +95,10 @@ func SaveAccountData(
|
|||
}
|
||||
}
|
||||
|
||||
if dataType == "m.fully_read" {
|
||||
if dataType == "m.fully_read" || dataType == "m.push_rules" {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: jsonerror.Forbidden("Unable to set read marker"),
|
||||
JSON: jsonerror.Forbidden(fmt.Sprintf("Unable to modify %q using this API", dataType)),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -64,11 +64,6 @@ const (
|
|||
sessionIDLength = 24
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Register prometheus metrics. They must be registered to be exposed.
|
||||
prometheus.MustRegister(amtRegUsers)
|
||||
}
|
||||
|
||||
// sessionsDict keeps track of completed auth stages for each session.
|
||||
// It shouldn't be passed by value because it contains a mutex.
|
||||
type sessionsDict struct {
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ import (
|
|||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
|
@ -60,6 +61,8 @@ func Setup(
|
|||
extRoomsProvider api.ExtraPublicRoomsProvider,
|
||||
mscCfg *config.MSCs, natsClient *nats.Conn,
|
||||
) {
|
||||
prometheus.MustRegister(amtRegUsers, sendEventDuration)
|
||||
|
||||
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
|
||||
userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg)
|
||||
|
||||
|
|
|
|||
|
|
@ -46,10 +46,6 @@ var (
|
|||
userRoomSendMutexes sync.Map // (roomID+userID) -> mutex. mutexes to ensure correct ordering of sendEvents
|
||||
)
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(sendEventDuration)
|
||||
}
|
||||
|
||||
var sendEventDuration = prometheus.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Namespace: "dendrite",
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client
|
|||
|
||||
if turnConfig.SharedSecret != "" {
|
||||
expiry := time.Now().Add(duration).Unix()
|
||||
resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID)
|
||||
mac := hmac.New(sha1.New, []byte(turnConfig.SharedSecret))
|
||||
_, err := mac.Write([]byte(resp.Username))
|
||||
|
||||
|
|
@ -60,7 +61,6 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client
|
|||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID)
|
||||
resp.Password = base64.StdEncoding.EncodeToString(mac.Sum(nil))
|
||||
} else if turnConfig.Username != "" && turnConfig.Password != "" {
|
||||
resp.Username = turnConfig.Username
|
||||
|
|
|
|||
|
|
@ -1,156 +0,0 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const usage = `Usage: %s
|
||||
|
||||
Create a single endpoint URL which clients can be pointed at.
|
||||
|
||||
The client-server API in Dendrite is split across multiple processes
|
||||
which listen on multiple ports. You cannot point a Matrix client at
|
||||
any of those ports, as there will be unimplemented functionality.
|
||||
In addition, all client-server API processes start with the additional
|
||||
path prefix '/api', which Matrix clients will be unaware of.
|
||||
|
||||
This tool will proxy requests for all client-server URLs and forward
|
||||
them to their respective process. It will also add the '/api' path
|
||||
prefix to incoming requests.
|
||||
|
||||
THIS TOOL IS FOR TESTING AND NOT INTENDED FOR PRODUCTION USE.
|
||||
|
||||
Arguments:
|
||||
|
||||
`
|
||||
|
||||
var (
|
||||
syncServerURL = flag.String("sync-api-server-url", "", "The base URL of the listening 'dendrite-sync-api-server' process. E.g. 'http://localhost:4200'")
|
||||
clientAPIURL = flag.String("client-api-server-url", "", "The base URL of the listening 'dendrite-client-api-server' process. E.g. 'http://localhost:4321'")
|
||||
mediaAPIURL = flag.String("media-api-server-url", "", "The base URL of the listening 'dendrite-media-api-server' process. E.g. 'http://localhost:7779'")
|
||||
bindAddress = flag.String("bind-address", ":8008", "The listening port for the proxy.")
|
||||
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS")
|
||||
keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS")
|
||||
)
|
||||
|
||||
func makeProxy(targetURL string) (*httputil.ReverseProxy, error) {
|
||||
targetURL = strings.TrimSuffix(targetURL, "/")
|
||||
|
||||
// Check that we can parse the URL.
|
||||
_, err := url.Parse(targetURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
// URL.Path() removes the % escaping from the path.
|
||||
// The % encoding will be added back when the url is encoded
|
||||
// when the request is forwarded.
|
||||
// This means that we will lose any unessecary escaping from the URL.
|
||||
// Pratically this means that any distinction between '%2F' and '/'
|
||||
// in the URL will be lost by the time it reaches the target.
|
||||
path := req.URL.Path
|
||||
log.WithFields(log.Fields{
|
||||
"path": path,
|
||||
"url": targetURL,
|
||||
"method": req.Method,
|
||||
}).Print("proxying request")
|
||||
newURL, err := url.Parse(targetURL)
|
||||
// Set the path separately as we need to preserve '#' characters
|
||||
// that would otherwise be interpreted as being the start of a URL
|
||||
// fragment.
|
||||
newURL.Path += path
|
||||
if err != nil {
|
||||
// We already checked that we can parse the URL
|
||||
// So this shouldn't ever get hit.
|
||||
panic(err)
|
||||
}
|
||||
// Copy the query parameters from the request.
|
||||
newURL.RawQuery = req.URL.RawQuery
|
||||
req.URL = newURL
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, usage, os.Args[0])
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if *syncServerURL == "" {
|
||||
flag.Usage()
|
||||
fmt.Fprintln(os.Stderr, "no --sync-api-server-url specified.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *clientAPIURL == "" {
|
||||
flag.Usage()
|
||||
fmt.Fprintln(os.Stderr, "no --client-api-server-url specified.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *mediaAPIURL == "" {
|
||||
flag.Usage()
|
||||
fmt.Fprintln(os.Stderr, "no --media-api-server-url specified.")
|
||||
os.Exit(1)
|
||||
}
|
||||
syncProxy, err := makeProxy(*syncServerURL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
clientProxy, err := makeProxy(*clientAPIURL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mediaProxy, err := makeProxy(*mediaAPIURL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
http.Handle("/_matrix/client/r0/sync", syncProxy)
|
||||
http.Handle("/_matrix/media/v1/", mediaProxy)
|
||||
http.Handle("/", clientProxy)
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: *bindAddress,
|
||||
ReadTimeout: 1 * time.Minute, // how long we wait for the client to send the entire request (after connection accept)
|
||||
WriteTimeout: 5 * time.Minute, // how long the proxy has to write the full response
|
||||
}
|
||||
|
||||
fmt.Println("Proxying requests to:")
|
||||
fmt.Println(" /_matrix/client/r0/sync => ", *syncServerURL+"/api/_matrix/client/r0/sync")
|
||||
fmt.Println(" /_matrix/media/v1 => ", *mediaAPIURL+"/api/_matrix/media/v1")
|
||||
fmt.Println(" /* => ", *clientAPIURL+"/api/*")
|
||||
fmt.Println("Listening on ", *bindAddress)
|
||||
if *certFile != "" && *keyFile != "" {
|
||||
panic(srv.ListenAndServeTLS(*certFile, *keyFile))
|
||||
} else {
|
||||
panic(srv.ListenAndServe())
|
||||
}
|
||||
}
|
||||
|
|
@ -25,7 +25,6 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
|
@ -47,6 +46,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
pineconeConnections "github.com/matrix-org/pinecone/connections"
|
||||
pineconeMulticast "github.com/matrix-org/pinecone/multicast"
|
||||
pineconeRouter "github.com/matrix-org/pinecone/router"
|
||||
pineconeSessions "github.com/matrix-org/pinecone/sessions"
|
||||
|
|
@ -90,6 +90,13 @@ func main() {
|
|||
}
|
||||
|
||||
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
|
||||
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
|
||||
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
|
||||
pManager := pineconeConnections.NewConnectionManager(pRouter)
|
||||
pMulticast.Start()
|
||||
if instancePeer != nil && *instancePeer != "" {
|
||||
pManager.AddPeer(*instancePeer)
|
||||
}
|
||||
|
||||
go func() {
|
||||
listener, err := net.Listen("tcp", *instanceListen)
|
||||
|
|
@ -119,36 +126,6 @@ func main() {
|
|||
}
|
||||
}()
|
||||
|
||||
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
|
||||
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
|
||||
pMulticast.Start()
|
||||
|
||||
connectToStaticPeer := func() {
|
||||
connected := map[string]bool{} // URI -> connected?
|
||||
for _, uri := range strings.Split(*instancePeer, ",") {
|
||||
connected[strings.TrimSpace(uri)] = false
|
||||
}
|
||||
attempt := func() {
|
||||
for k := range connected {
|
||||
connected[k] = false
|
||||
}
|
||||
for _, info := range pRouter.Peers() {
|
||||
connected[info.URI] = true
|
||||
}
|
||||
for k, online := range connected {
|
||||
if !online {
|
||||
if err := conn.ConnectToPeer(pRouter, k); err != nil {
|
||||
logrus.WithError(err).Error("Failed to connect to static peer")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for {
|
||||
attempt()
|
||||
time.Sleep(time.Second * 5)
|
||||
}
|
||||
}
|
||||
|
||||
cfg := &config.Dendrite{}
|
||||
cfg.Defaults(true)
|
||||
cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))
|
||||
|
|
@ -268,7 +245,6 @@ func main() {
|
|||
Handler: pMux,
|
||||
}
|
||||
|
||||
go connectToStaticPeer()
|
||||
go func() {
|
||||
pubkey := pRouter.PublicKey()
|
||||
logrus.Info("Listening on ", hex.EncodeToString(pubkey[:]))
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ import (
|
|||
"encoding/hex"
|
||||
"fmt"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/appservice"
|
||||
|
|
@ -44,6 +43,7 @@ import (
|
|||
|
||||
_ "github.com/matrix-org/go-sqlite3-js"
|
||||
|
||||
pineconeConnections "github.com/matrix-org/pinecone/connections"
|
||||
pineconeRouter "github.com/matrix-org/pinecone/router"
|
||||
pineconeSessions "github.com/matrix-org/pinecone/sessions"
|
||||
)
|
||||
|
|
@ -154,6 +154,8 @@ func startup() {
|
|||
|
||||
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
|
||||
pSessions := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
|
||||
pManager := pineconeConnections.NewConnectionManager(pRouter)
|
||||
pManager.AddPeer("wss://pinecone.matrix.org/public")
|
||||
|
||||
cfg := &config.Dendrite{}
|
||||
cfg.Defaults(true)
|
||||
|
|
@ -237,20 +239,4 @@ func startup() {
|
|||
}
|
||||
s.ListenAndServe("fetch")
|
||||
}()
|
||||
|
||||
// Connect to the static peer
|
||||
go func() {
|
||||
for {
|
||||
if pRouter.PeerCount(pineconeRouter.PeerTypeRemote) == 0 {
|
||||
if err := conn.ConnectToPeer(pRouter, publicPeer); err != nil {
|
||||
logrus.WithError(err).Error("Failed to connect to static peer")
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-base.ProcessContext.Context().Done():
|
||||
return
|
||||
case <-time.After(time.Second * 5):
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,138 +0,0 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const usage = `Usage: %s
|
||||
|
||||
Create a single endpoint URL which remote matrix servers can be pointed at.
|
||||
|
||||
The server-server API in Dendrite is split across multiple processes
|
||||
which listen on multiple ports. You cannot point a Matrix server at
|
||||
any of those ports, as there will be unimplemented functionality.
|
||||
In addition, all server-server API processes start with the additional
|
||||
path prefix '/api', which Matrix servers will be unaware of.
|
||||
|
||||
This tool will proxy requests for all server-server URLs and forward
|
||||
them to their respective process. It will also add the '/api' path
|
||||
prefix to incoming requests.
|
||||
|
||||
THIS TOOL IS FOR TESTING AND NOT INTENDED FOR PRODUCTION USE.
|
||||
|
||||
Arguments:
|
||||
|
||||
`
|
||||
|
||||
var (
|
||||
federationAPIURL = flag.String("federation-api-url", "", "The base URL of the listening 'dendrite-federation-api-server' process. E.g. 'http://localhost:4200'")
|
||||
mediaAPIURL = flag.String("media-api-server-url", "", "The base URL of the listening 'dendrite-media-api-server' process. E.g. 'http://localhost:7779'")
|
||||
bindAddress = flag.String("bind-address", ":8448", "The listening port for the proxy.")
|
||||
certFile = flag.String("tls-cert", "server.crt", "The PEM formatted X509 certificate to use for TLS")
|
||||
keyFile = flag.String("tls-key", "server.key", "The PEM private key to use for TLS")
|
||||
)
|
||||
|
||||
func makeProxy(targetURL string) (*httputil.ReverseProxy, error) {
|
||||
if !strings.HasSuffix(targetURL, "/") {
|
||||
targetURL += "/"
|
||||
}
|
||||
// Check that we can parse the URL.
|
||||
_, err := url.Parse(targetURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
// URL.Path() removes the % escaping from the path.
|
||||
// The % encoding will be added back when the url is encoded
|
||||
// when the request is forwarded.
|
||||
// This means that we will lose any unessecary escaping from the URL.
|
||||
// Pratically this means that any distinction between '%2F' and '/'
|
||||
// in the URL will be lost by the time it reaches the target.
|
||||
path := req.URL.Path
|
||||
log.WithFields(log.Fields{
|
||||
"path": path,
|
||||
"url": targetURL,
|
||||
"method": req.Method,
|
||||
}).Print("proxying request")
|
||||
newURL, err := url.Parse(targetURL + path)
|
||||
if err != nil {
|
||||
// We already checked that we can parse the URL
|
||||
// So this shouldn't ever get hit.
|
||||
panic(err)
|
||||
}
|
||||
// Copy the query parameters from the request.
|
||||
newURL.RawQuery = req.URL.RawQuery
|
||||
req.URL = newURL
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, usage, os.Args[0])
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if *federationAPIURL == "" {
|
||||
flag.Usage()
|
||||
fmt.Fprintln(os.Stderr, "no --federation-api-url specified.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *mediaAPIURL == "" {
|
||||
flag.Usage()
|
||||
fmt.Fprintln(os.Stderr, "no --media-api-server-url specified.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
federationProxy, err := makeProxy(*federationAPIURL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
mediaProxy, err := makeProxy(*mediaAPIURL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
http.Handle("/_matrix/media/v1/", mediaProxy)
|
||||
http.Handle("/", federationProxy)
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: *bindAddress,
|
||||
ReadTimeout: 1 * time.Minute, // how long we wait for the client to send the entire request (after connection accept)
|
||||
WriteTimeout: 5 * time.Minute, // how long the proxy has to write the full response
|
||||
}
|
||||
|
||||
fmt.Println("Proxying requests to:")
|
||||
fmt.Println(" /_matrix/media/v1 => ", *mediaAPIURL+"/api/_matrix/media/v1")
|
||||
fmt.Println(" /* => ", *federationAPIURL+"/api/*")
|
||||
fmt.Println("Listening on ", *bindAddress)
|
||||
panic(srv.ListenAndServeTLS(*certFile, *keyFile))
|
||||
}
|
||||
|
|
@ -29,6 +29,7 @@ import (
|
|||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
|
@ -53,6 +54,10 @@ func Setup(
|
|||
servers federationAPI.ServersInRoomProvider,
|
||||
producer *producers.SyncAPIProducer,
|
||||
) {
|
||||
prometheus.MustRegister(
|
||||
pduCountTotal, eduCountTotal,
|
||||
)
|
||||
|
||||
v2keysmux := keyMux.PathPrefix("/v2").Subrouter()
|
||||
v1fedmux := fedMux.PathPrefix("/v1").Subrouter()
|
||||
v2fedmux := fedMux.PathPrefix("/v2").Subrouter()
|
||||
|
|
|
|||
|
|
@ -74,12 +74,6 @@ var (
|
|||
)
|
||||
)
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(
|
||||
pduCountTotal, eduCountTotal,
|
||||
)
|
||||
}
|
||||
|
||||
var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse
|
||||
|
||||
// Send implements /_matrix/federation/v1/send/{txnID}
|
||||
|
|
|
|||
38
go.mod
38
go.mod
|
|
@ -12,20 +12,20 @@ require (
|
|||
github.com/MFAshby/stdemuxerhook v1.0.0
|
||||
github.com/Masterminds/semver/v3 v3.1.1
|
||||
github.com/codeclysm/extract v2.2.0+incompatible
|
||||
github.com/containerd/containerd v1.5.9 // indirect
|
||||
github.com/docker/docker v20.10.12+incompatible
|
||||
github.com/containerd/containerd v1.6.2 // indirect
|
||||
github.com/docker/docker v20.10.14+incompatible
|
||||
github.com/docker/go-connections v0.4.0
|
||||
github.com/frankban/quicktest v1.14.0 // indirect
|
||||
github.com/getsentry/sentry-go v0.12.0
|
||||
github.com/frankban/quicktest v1.14.3 // indirect
|
||||
github.com/getsentry/sentry-go v0.13.0
|
||||
github.com/gologme/log v1.3.0
|
||||
github.com/google/go-cmp v0.5.6
|
||||
github.com/google/uuid v1.2.0
|
||||
github.com/google/go-cmp v0.5.7
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/gorilla/websocket v1.4.2
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/h2non/filetype v1.1.3 // indirect
|
||||
github.com/hashicorp/golang-lru v0.5.4
|
||||
github.com/juju/testing v0.0.0-20211215003918-77eb13d6cad2 // indirect
|
||||
github.com/lib/pq v1.10.4
|
||||
github.com/juju/testing v0.0.0-20220203020004-a0ff61f03494 // indirect
|
||||
github.com/lib/pq v1.10.5
|
||||
github.com/libp2p/go-libp2p v0.13.0
|
||||
github.com/libp2p/go-libp2p-circuit v0.4.0
|
||||
github.com/libp2p/go-libp2p-core v0.8.3
|
||||
|
|
@ -36,18 +36,18 @@ require (
|
|||
github.com/libp2p/go-libp2p-record v0.1.3
|
||||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
||||
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5
|
||||
github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f
|
||||
github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||
github.com/mattn/go-sqlite3 v1.14.10
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/nats-io/nats-server/v2 v2.7.4-0.20220309205833-773636c1c5bb
|
||||
github.com/nats-io/nats.go v1.13.1-0.20220308171302-2f2f6968e98d
|
||||
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||
github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31
|
||||
github.com/opencontainers/image-spec v1.0.2 // indirect
|
||||
github.com/opentracing/opentracing-go v1.2.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/pkg/errors v0.9.1
|
||||
|
|
@ -60,14 +60,14 @@ require (
|
|||
github.com/uber/jaeger-lib v2.4.1+incompatible
|
||||
github.com/yggdrasil-network/yggdrasil-go v0.4.3
|
||||
go.uber.org/atomic v1.9.0
|
||||
golang.org/x/crypto v0.0.0-20220214200702-86341886e292
|
||||
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
|
||||
golang.org/x/mobile v0.0.0-20220325161704-447654d348e3
|
||||
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
|
||||
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29
|
||||
golang.org/x/image v0.0.0-20220321031419-a8550c1d254a
|
||||
golang.org/x/mobile v0.0.0-20220407111146-e579adbbc4a2
|
||||
golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3
|
||||
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
|
||||
gopkg.in/h2non/bimg.v1 v1.1.5
|
||||
gopkg.in/h2non/bimg.v1 v1.1.9
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
|
||||
nhooyr.io/websocket v1.8.7
|
||||
)
|
||||
|
||||
|
|
|
|||
86
internal/caching/cache_lazy_load_members.go
Normal file
86
internal/caching/cache_lazy_load_members.go
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
package caching
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
const (
|
||||
LazyLoadCacheName = "lazy_load_members"
|
||||
LazyLoadCacheMaxEntries = 128
|
||||
LazyLoadCacheMaxUserEntries = 128
|
||||
LazyLoadCacheMutable = true
|
||||
LazyLoadCacheMaxAge = time.Minute * 30
|
||||
)
|
||||
|
||||
type LazyLoadCache struct {
|
||||
// InMemoryLRUCachePartition containing other InMemoryLRUCachePartitions
|
||||
// with the actual cached members
|
||||
userCaches *InMemoryLRUCachePartition
|
||||
}
|
||||
|
||||
// NewLazyLoadCache creates a new LazyLoadCache.
|
||||
func NewLazyLoadCache() (*LazyLoadCache, error) {
|
||||
cache, err := NewInMemoryLRUCachePartition(
|
||||
LazyLoadCacheName,
|
||||
LazyLoadCacheMutable,
|
||||
LazyLoadCacheMaxEntries,
|
||||
LazyLoadCacheMaxAge,
|
||||
true,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go cacheCleaner(cache)
|
||||
return &LazyLoadCache{
|
||||
userCaches: cache,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *LazyLoadCache) lazyLoadCacheForUser(device *userapi.Device) (*InMemoryLRUCachePartition, error) {
|
||||
cacheName := fmt.Sprintf("%s/%s", device.UserID, device.ID)
|
||||
userCache, ok := c.userCaches.Get(cacheName)
|
||||
if ok && userCache != nil {
|
||||
if cache, ok := userCache.(*InMemoryLRUCachePartition); ok {
|
||||
return cache, nil
|
||||
}
|
||||
}
|
||||
cache, err := NewInMemoryLRUCachePartition(
|
||||
LazyLoadCacheName,
|
||||
LazyLoadCacheMutable,
|
||||
LazyLoadCacheMaxUserEntries,
|
||||
LazyLoadCacheMaxAge,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.userCaches.Set(cacheName, cache)
|
||||
go cacheCleaner(cache)
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
func (c *LazyLoadCache) StoreLazyLoadedUser(device *userapi.Device, roomID, userID, eventID string) {
|
||||
cache, err := c.lazyLoadCacheForUser(device)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID)
|
||||
cache.Set(cacheKey, eventID)
|
||||
}
|
||||
|
||||
func (c *LazyLoadCache) IsLazyLoadedUserCached(device *userapi.Device, roomID, userID string) (string, bool) {
|
||||
cache, err := c.lazyLoadCacheForUser(device)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID)
|
||||
val, ok := cache.Get(cacheKey)
|
||||
if !ok {
|
||||
return "", ok
|
||||
}
|
||||
return val.(string), ok
|
||||
}
|
||||
|
|
@ -169,8 +169,9 @@ func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request)
|
|||
return promhttp.InstrumentHandlerCounter(
|
||||
promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: metricsName,
|
||||
Help: "Total number of http requests for HTML resources",
|
||||
Name: metricsName,
|
||||
Help: "Total number of http requests for HTML resources",
|
||||
Namespace: "dendrite",
|
||||
},
|
||||
[]string{"code"},
|
||||
),
|
||||
|
|
@ -201,7 +202,28 @@ func MakeInternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
|
|||
h.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
return http.HandlerFunc(withSpan)
|
||||
return promhttp.InstrumentHandlerCounter(
|
||||
promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: metricsName + "_requests_total",
|
||||
Help: "Total number of internal API calls",
|
||||
Namespace: "dendrite",
|
||||
},
|
||||
[]string{"code"},
|
||||
),
|
||||
promhttp.InstrumentHandlerResponseSize(
|
||||
promauto.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Namespace: "dendrite",
|
||||
Name: metricsName + "_response_size_bytes",
|
||||
Help: "A histogram of response sizes for requests.",
|
||||
Buckets: []float64{200, 500, 900, 1500, 5000, 15000, 50000, 100000},
|
||||
},
|
||||
[]string{},
|
||||
),
|
||||
http.HandlerFunc(withSpan),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// MakeFedAPI makes an http.Handler that checks matrix federation authentication.
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@ package pushgateway
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// A Client is how interactions with a Push Gateway is done.
|
||||
|
|
@ -47,11 +45,11 @@ type Counts struct {
|
|||
}
|
||||
|
||||
type Device struct {
|
||||
AppID string `json:"app_id"` // Required
|
||||
Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys.
|
||||
PushKey string `json:"pushkey"` // Required
|
||||
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
|
||||
Tweaks map[string]interface{} `json:"tweaks,omitempty"`
|
||||
AppID string `json:"app_id"` // Required
|
||||
Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys.
|
||||
PushKey string `json:"pushkey"` // Required
|
||||
PushKeyTS int64 `json:"pushkey_ts,omitempty"`
|
||||
Tweaks map[string]interface{} `json:"tweaks,omitempty"`
|
||||
}
|
||||
|
||||
type Prio string
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ func AddPublicRoutes(
|
|||
userAPI userapi.UserInternalAPI,
|
||||
client *gomatrixserverlib.Client,
|
||||
) {
|
||||
mediaDB, err := storage.Open(&cfg.Database)
|
||||
mediaDB, err := storage.NewMediaAPIDatasource(&cfg.Database)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to media db")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
|
|
@ -311,6 +312,26 @@ func (r *uploadRequest) storeFileAndMetadata(
|
|||
}
|
||||
|
||||
go func() {
|
||||
file, err := os.Open(string(finalPath))
|
||||
if err != nil {
|
||||
r.Logger.WithError(err).Error("unable to open file")
|
||||
return
|
||||
}
|
||||
defer file.Close() // nolint: errcheck
|
||||
// http.DetectContentType only needs 512 bytes
|
||||
buf := make([]byte, 512)
|
||||
_, err = file.Read(buf)
|
||||
if err != nil {
|
||||
r.Logger.WithError(err).Error("unable to read file")
|
||||
return
|
||||
}
|
||||
// Check if we need to generate thumbnails
|
||||
fileType := http.DetectContentType(buf)
|
||||
if !strings.HasPrefix(fileType, "image") {
|
||||
r.Logger.WithField("contentType", fileType).Debugf("uploaded file is not an image or can not be thumbnailed, not generating thumbnails")
|
||||
return
|
||||
}
|
||||
|
||||
busy, err := thumbnailer.GenerateThumbnails(
|
||||
context.Background(), finalPath, thumbnailSizes, r.MediaMetadata,
|
||||
activeThumbnailGeneration, maxThumbnailGenerators, db, r.Logger,
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ func Test_uploadRequest_doUpload(t *testing.T) {
|
|||
_ = os.Mkdir(testdataPath, os.ModePerm)
|
||||
defer fileutils.RemoveDir(types.Path(testdataPath), nil)
|
||||
|
||||
db, err := storage.Open(&config.DatabaseOptions{
|
||||
db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{
|
||||
ConnectionString: "file::memory:?cache=shared",
|
||||
MaxOpenConnections: 100,
|
||||
MaxIdleConnections: 2,
|
||||
|
|
|
|||
|
|
@ -22,9 +22,17 @@ import (
|
|||
)
|
||||
|
||||
type Database interface {
|
||||
MediaRepository
|
||||
Thumbnails
|
||||
}
|
||||
|
||||
type MediaRepository interface {
|
||||
StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error
|
||||
GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
|
||||
GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
|
||||
}
|
||||
|
||||
type Thumbnails interface {
|
||||
StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error
|
||||
GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error)
|
||||
GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ import (
|
|||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
|
@ -69,24 +71,25 @@ type mediaStatements struct {
|
|||
selectMediaByHashStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *mediaStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(mediaSchema)
|
||||
func NewPostgresMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) {
|
||||
s := &mediaStatements{}
|
||||
_, err := db.Exec(mediaSchema)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return statementList{
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertMediaStmt, insertMediaSQL},
|
||||
{&s.selectMediaStmt, selectMediaSQL},
|
||||
{&s.selectMediaByHashStmt, selectMediaByHashSQL},
|
||||
}.prepare(db)
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *mediaStatements) insertMedia(
|
||||
ctx context.Context, mediaMetadata *types.MediaMetadata,
|
||||
func (s *mediaStatements) InsertMedia(
|
||||
ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata,
|
||||
) error {
|
||||
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
|
||||
_, err := s.insertMediaStmt.ExecContext(
|
||||
mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
|
||||
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext(
|
||||
ctx,
|
||||
mediaMetadata.MediaID,
|
||||
mediaMetadata.Origin,
|
||||
|
|
@ -100,14 +103,14 @@ func (s *mediaStatements) insertMedia(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *mediaStatements) selectMedia(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
func (s *mediaStatements) SelectMedia(
|
||||
ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata := types.MediaMetadata{
|
||||
MediaID: mediaID,
|
||||
Origin: mediaOrigin,
|
||||
}
|
||||
err := s.selectMediaStmt.QueryRowContext(
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext(
|
||||
ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
|
||||
).Scan(
|
||||
&mediaMetadata.ContentType,
|
||||
|
|
@ -120,14 +123,14 @@ func (s *mediaStatements) selectMedia(
|
|||
return &mediaMetadata, err
|
||||
}
|
||||
|
||||
func (s *mediaStatements) selectMediaByHash(
|
||||
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
func (s *mediaStatements) SelectMediaByHash(
|
||||
ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata := types.MediaMetadata{
|
||||
Base64Hash: mediaHash,
|
||||
Origin: mediaOrigin,
|
||||
}
|
||||
err := s.selectMediaStmt.QueryRowContext(
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext(
|
||||
ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
|
||||
).Scan(
|
||||
&mediaMetadata.ContentType,
|
||||
|
|
|
|||
46
mediaapi/storage/postgres/mediaapi.go
Normal file
46
mediaapi/storage/postgres/mediaapi.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
// Import the postgres database driver.
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
// NewDatabase opens a postgres database.
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mediaRepo, err := NewPostgresMediaRepositoryTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
thumbnails, err := NewPostgresThumbnailsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &shared.Database{
|
||||
MediaRepository: mediaRepo,
|
||||
Thumbnails: thumbnails,
|
||||
DB: db,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// FIXME: This should be made internal!
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
|
||||
type statementList []struct {
|
||||
statement **sql.Stmt
|
||||
sql string
|
||||
}
|
||||
|
||||
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
|
||||
func (s statementList) prepare(db *sql.DB) (err error) {
|
||||
for _, statement := range s {
|
||||
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -21,6 +21,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
|
@ -63,7 +65,7 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE
|
|||
|
||||
// Note: this selects all thumbnails for a media_origin and media_id
|
||||
const selectThumbnailsSQL = `
|
||||
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2
|
||||
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC
|
||||
`
|
||||
|
||||
type thumbnailStatements struct {
|
||||
|
|
@ -72,24 +74,25 @@ type thumbnailStatements struct {
|
|||
selectThumbnailsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(thumbnailSchema)
|
||||
func NewPostgresThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) {
|
||||
s := &thumbnailStatements{}
|
||||
_, err := db.Exec(thumbnailSchema)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return statementList{
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertThumbnailStmt, insertThumbnailSQL},
|
||||
{&s.selectThumbnailStmt, selectThumbnailSQL},
|
||||
{&s.selectThumbnailsStmt, selectThumbnailsSQL},
|
||||
}.prepare(db)
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) insertThumbnail(
|
||||
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
|
||||
func (s *thumbnailStatements) InsertThumbnail(
|
||||
ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata,
|
||||
) error {
|
||||
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
|
||||
_, err := s.insertThumbnailStmt.ExecContext(
|
||||
thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
|
||||
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext(
|
||||
ctx,
|
||||
thumbnailMetadata.MediaMetadata.MediaID,
|
||||
thumbnailMetadata.MediaMetadata.Origin,
|
||||
|
|
@ -103,8 +106,9 @@ func (s *thumbnailStatements) insertThumbnail(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) selectThumbnail(
|
||||
func (s *thumbnailStatements) SelectThumbnail(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
mediaID types.MediaID,
|
||||
mediaOrigin gomatrixserverlib.ServerName,
|
||||
width, height int,
|
||||
|
|
@ -121,7 +125,7 @@ func (s *thumbnailStatements) selectThumbnail(
|
|||
ResizeMethod: resizeMethod,
|
||||
},
|
||||
}
|
||||
err := s.selectThumbnailStmt.QueryRowContext(
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext(
|
||||
ctx,
|
||||
thumbnailMetadata.MediaMetadata.MediaID,
|
||||
thumbnailMetadata.MediaMetadata.Origin,
|
||||
|
|
@ -136,10 +140,10 @@ func (s *thumbnailStatements) selectThumbnail(
|
|||
return &thumbnailMetadata, err
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) selectThumbnails(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
func (s *thumbnailStatements) SelectThumbnails(
|
||||
ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) ([]*types.ThumbnailMetadata, error) {
|
||||
rows, err := s.selectThumbnailsStmt.QueryContext(
|
||||
rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext(
|
||||
ctx, mediaID, mediaOrigin,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
|
|
@ -13,54 +12,38 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
// Import the postgres database driver.
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// Database is used to store metadata about a repository of media files.
|
||||
type Database struct {
|
||||
statements statements
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// Open opens a postgres database.
|
||||
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||
var d Database
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.statements.prepare(d.db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &d, nil
|
||||
DB *sql.DB
|
||||
Writer sqlutil.Writer
|
||||
MediaRepository tables.MediaRepository
|
||||
Thumbnails tables.Thumbnails
|
||||
}
|
||||
|
||||
// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
|
||||
// Returns an error if the combination of MediaID and Origin are not unique in the table.
|
||||
func (d *Database) StoreMediaMetadata(
|
||||
ctx context.Context, mediaMetadata *types.MediaMetadata,
|
||||
) error {
|
||||
return d.statements.media.insertMedia(ctx, mediaMetadata)
|
||||
func (d Database) StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.MediaRepository.InsertMedia(ctx, txn, mediaMetadata)
|
||||
})
|
||||
}
|
||||
|
||||
// GetMediaMetadata returns metadata about media stored on this server.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there is no metadata associated with this media.
|
||||
func (d *Database) GetMediaMetadata(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin)
|
||||
func (d Database) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
|
||||
mediaMetadata, err := d.MediaRepository.SelectMedia(ctx, nil, mediaID, mediaOrigin)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
@ -70,10 +53,8 @@ func (d *Database) GetMediaMetadata(
|
|||
// GetMediaMetadataByHash returns metadata about media stored on this server.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there is no metadata associated with this media.
|
||||
func (d *Database) GetMediaMetadataByHash(
|
||||
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin)
|
||||
func (d Database) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
|
||||
mediaMetadata, err := d.MediaRepository.SelectMediaByHash(ctx, nil, mediaHash, mediaOrigin)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
@ -82,40 +63,36 @@ func (d *Database) GetMediaMetadataByHash(
|
|||
|
||||
// StoreThumbnail inserts the metadata about the thumbnail into the database.
|
||||
// Returns an error if the combination of MediaID and Origin are not unique in the table.
|
||||
func (d *Database) StoreThumbnail(
|
||||
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
|
||||
) error {
|
||||
return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata)
|
||||
func (d Database) StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Thumbnails.InsertThumbnail(ctx, txn, thumbnailMetadata)
|
||||
})
|
||||
}
|
||||
|
||||
// GetThumbnail returns metadata about a specific thumbnail.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there is no metadata associated with this thumbnail.
|
||||
func (d *Database) GetThumbnail(
|
||||
ctx context.Context,
|
||||
mediaID types.MediaID,
|
||||
mediaOrigin gomatrixserverlib.ServerName,
|
||||
width, height int,
|
||||
resizeMethod string,
|
||||
) (*types.ThumbnailMetadata, error) {
|
||||
thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail(
|
||||
ctx, mediaID, mediaOrigin, width, height, resizeMethod,
|
||||
)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
func (d Database) GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) {
|
||||
metadata, err := d.Thumbnails.SelectThumbnail(ctx, nil, mediaID, mediaOrigin, width, height, resizeMethod)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return thumbnailMetadata, err
|
||||
return metadata, err
|
||||
}
|
||||
|
||||
// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there are no thumbnails associated with this media.
|
||||
func (d *Database) GetThumbnails(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) ([]*types.ThumbnailMetadata, error) {
|
||||
thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
func (d Database) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) {
|
||||
metadatas, err := d.Thumbnails.SelectThumbnails(ctx, nil, mediaID, mediaOrigin)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return thumbnails, err
|
||||
return metadatas, err
|
||||
}
|
||||
|
|
@ -21,6 +21,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
|
@ -66,57 +67,53 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, media_id, user_i
|
|||
|
||||
type mediaStatements struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
insertMediaStmt *sql.Stmt
|
||||
selectMediaStmt *sql.Stmt
|
||||
selectMediaByHashStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||
s.db = db
|
||||
s.writer = writer
|
||||
|
||||
_, err = db.Exec(mediaSchema)
|
||||
func NewSQLiteMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) {
|
||||
s := &mediaStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(mediaSchema)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return statementList{
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertMediaStmt, insertMediaSQL},
|
||||
{&s.selectMediaStmt, selectMediaSQL},
|
||||
{&s.selectMediaByHashStmt, selectMediaByHashSQL},
|
||||
}.prepare(db)
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *mediaStatements) insertMedia(
|
||||
ctx context.Context, mediaMetadata *types.MediaMetadata,
|
||||
func (s *mediaStatements) InsertMedia(
|
||||
ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata,
|
||||
) error {
|
||||
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertMediaStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
mediaMetadata.MediaID,
|
||||
mediaMetadata.Origin,
|
||||
mediaMetadata.ContentType,
|
||||
mediaMetadata.FileSizeBytes,
|
||||
mediaMetadata.CreationTimestamp,
|
||||
mediaMetadata.UploadName,
|
||||
mediaMetadata.Base64Hash,
|
||||
mediaMetadata.UserID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
|
||||
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext(
|
||||
ctx,
|
||||
mediaMetadata.MediaID,
|
||||
mediaMetadata.Origin,
|
||||
mediaMetadata.ContentType,
|
||||
mediaMetadata.FileSizeBytes,
|
||||
mediaMetadata.CreationTimestamp,
|
||||
mediaMetadata.UploadName,
|
||||
mediaMetadata.Base64Hash,
|
||||
mediaMetadata.UserID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *mediaStatements) selectMedia(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
func (s *mediaStatements) SelectMedia(
|
||||
ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata := types.MediaMetadata{
|
||||
MediaID: mediaID,
|
||||
Origin: mediaOrigin,
|
||||
}
|
||||
err := s.selectMediaStmt.QueryRowContext(
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext(
|
||||
ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
|
||||
).Scan(
|
||||
&mediaMetadata.ContentType,
|
||||
|
|
@ -129,14 +126,14 @@ func (s *mediaStatements) selectMedia(
|
|||
return &mediaMetadata, err
|
||||
}
|
||||
|
||||
func (s *mediaStatements) selectMediaByHash(
|
||||
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
func (s *mediaStatements) SelectMediaByHash(
|
||||
ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata := types.MediaMetadata{
|
||||
Base64Hash: mediaHash,
|
||||
Origin: mediaOrigin,
|
||||
}
|
||||
err := s.selectMediaStmt.QueryRowContext(
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext(
|
||||
ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
|
||||
).Scan(
|
||||
&mediaMetadata.ContentType,
|
||||
|
|
|
|||
|
|
@ -16,23 +16,30 @@
|
|||
package sqlite3
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
// Import the postgres database driver.
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
type statements struct {
|
||||
media mediaStatements
|
||||
thumbnail thumbnailStatements
|
||||
}
|
||||
|
||||
func (s *statements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||
if err = s.media.prepare(db, writer); err != nil {
|
||||
return
|
||||
// NewDatabase opens a SQLIte database.
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = s.thumbnail.prepare(db, writer); err != nil {
|
||||
return
|
||||
mediaRepo, err := NewSQLiteMediaRepositoryTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
thumbnails, err := NewSQLiteThumbnailsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &shared.Database{
|
||||
MediaRepository: mediaRepo,
|
||||
Thumbnails: thumbnails,
|
||||
DB: db,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// FIXME: This should be made internal!
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
|
||||
type statementList []struct {
|
||||
statement **sql.Stmt
|
||||
sql string
|
||||
}
|
||||
|
||||
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
|
||||
func (s statementList) prepare(db *sql.DB) (err error) {
|
||||
for _, statement := range s {
|
||||
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -1,123 +0,0 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
// Import the postgres database driver.
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// Database is used to store metadata about a repository of media files.
|
||||
type Database struct {
|
||||
statements statements
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
}
|
||||
|
||||
// Open opens a postgres database.
|
||||
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||
d := Database{
|
||||
writer: sqlutil.NewExclusiveWriter(),
|
||||
}
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.statements.prepare(d.db, d.writer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
|
||||
// Returns an error if the combination of MediaID and Origin are not unique in the table.
|
||||
func (d *Database) StoreMediaMetadata(
|
||||
ctx context.Context, mediaMetadata *types.MediaMetadata,
|
||||
) error {
|
||||
return d.statements.media.insertMedia(ctx, mediaMetadata)
|
||||
}
|
||||
|
||||
// GetMediaMetadata returns metadata about media stored on this server.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there is no metadata associated with this media.
|
||||
func (d *Database) GetMediaMetadata(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return mediaMetadata, err
|
||||
}
|
||||
|
||||
// GetMediaMetadataByHash returns metadata about media stored on this server.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there is no metadata associated with this media.
|
||||
func (d *Database) GetMediaMetadataByHash(
|
||||
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return mediaMetadata, err
|
||||
}
|
||||
|
||||
// StoreThumbnail inserts the metadata about the thumbnail into the database.
|
||||
// Returns an error if the combination of MediaID and Origin are not unique in the table.
|
||||
func (d *Database) StoreThumbnail(
|
||||
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
|
||||
) error {
|
||||
return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata)
|
||||
}
|
||||
|
||||
// GetThumbnail returns metadata about a specific thumbnail.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there is no metadata associated with this thumbnail.
|
||||
func (d *Database) GetThumbnail(
|
||||
ctx context.Context,
|
||||
mediaID types.MediaID,
|
||||
mediaOrigin gomatrixserverlib.ServerName,
|
||||
width, height int,
|
||||
resizeMethod string,
|
||||
) (*types.ThumbnailMetadata, error) {
|
||||
thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail(
|
||||
ctx, mediaID, mediaOrigin, width, height, resizeMethod,
|
||||
)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return thumbnailMetadata, err
|
||||
}
|
||||
|
||||
// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there are no thumbnails associated with this media.
|
||||
func (d *Database) GetThumbnails(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) ([]*types.ThumbnailMetadata, error) {
|
||||
thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return thumbnails, err
|
||||
}
|
||||
|
|
@ -22,6 +22,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
|
@ -54,55 +55,48 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE
|
|||
|
||||
// Note: this selects all thumbnails for a media_origin and media_id
|
||||
const selectThumbnailsSQL = `
|
||||
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2
|
||||
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC
|
||||
`
|
||||
|
||||
type thumbnailStatements struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
insertThumbnailStmt *sql.Stmt
|
||||
selectThumbnailStmt *sql.Stmt
|
||||
selectThumbnailsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||
_, err = db.Exec(thumbnailSchema)
|
||||
func NewSQLiteThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) {
|
||||
s := &thumbnailStatements{}
|
||||
_, err := db.Exec(thumbnailSchema)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
s.db = db
|
||||
s.writer = writer
|
||||
|
||||
return statementList{
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertThumbnailStmt, insertThumbnailSQL},
|
||||
{&s.selectThumbnailStmt, selectThumbnailSQL},
|
||||
{&s.selectThumbnailsStmt, selectThumbnailsSQL},
|
||||
}.prepare(db)
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) insertThumbnail(
|
||||
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
|
||||
) error {
|
||||
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
thumbnailMetadata.MediaMetadata.MediaID,
|
||||
thumbnailMetadata.MediaMetadata.Origin,
|
||||
thumbnailMetadata.MediaMetadata.ContentType,
|
||||
thumbnailMetadata.MediaMetadata.FileSizeBytes,
|
||||
thumbnailMetadata.MediaMetadata.CreationTimestamp,
|
||||
thumbnailMetadata.ThumbnailSize.Width,
|
||||
thumbnailMetadata.ThumbnailSize.Height,
|
||||
thumbnailMetadata.ThumbnailSize.ResizeMethod,
|
||||
)
|
||||
return err
|
||||
})
|
||||
func (s *thumbnailStatements) InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error {
|
||||
thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
|
||||
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext(
|
||||
ctx,
|
||||
thumbnailMetadata.MediaMetadata.MediaID,
|
||||
thumbnailMetadata.MediaMetadata.Origin,
|
||||
thumbnailMetadata.MediaMetadata.ContentType,
|
||||
thumbnailMetadata.MediaMetadata.FileSizeBytes,
|
||||
thumbnailMetadata.MediaMetadata.CreationTimestamp,
|
||||
thumbnailMetadata.ThumbnailSize.Width,
|
||||
thumbnailMetadata.ThumbnailSize.Height,
|
||||
thumbnailMetadata.ThumbnailSize.ResizeMethod,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) selectThumbnail(
|
||||
func (s *thumbnailStatements) SelectThumbnail(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
mediaID types.MediaID,
|
||||
mediaOrigin gomatrixserverlib.ServerName,
|
||||
width, height int,
|
||||
|
|
@ -119,7 +113,7 @@ func (s *thumbnailStatements) selectThumbnail(
|
|||
ResizeMethod: resizeMethod,
|
||||
},
|
||||
}
|
||||
err := s.selectThumbnailStmt.QueryRowContext(
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext(
|
||||
ctx,
|
||||
thumbnailMetadata.MediaMetadata.MediaID,
|
||||
thumbnailMetadata.MediaMetadata.Origin,
|
||||
|
|
@ -134,10 +128,11 @@ func (s *thumbnailStatements) selectThumbnail(
|
|||
return &thumbnailMetadata, err
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) selectThumbnails(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
func (s *thumbnailStatements) SelectThumbnails(
|
||||
ctx context.Context, txn *sql.Tx, mediaID types.MediaID,
|
||||
mediaOrigin gomatrixserverlib.ServerName,
|
||||
) ([]*types.ThumbnailMetadata, error) {
|
||||
rows, err := s.selectThumbnailsStmt.QueryContext(
|
||||
rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext(
|
||||
ctx, mediaID, mediaOrigin,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -25,13 +25,13 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
// Open opens a postgres database.
|
||||
func Open(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||
// NewMediaAPIDatasource opens a database connection.
|
||||
func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.Open(dbProperties)
|
||||
return sqlite3.NewDatabase(dbProperties)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return postgres.Open(dbProperties)
|
||||
return postgres.NewDatabase(dbProperties)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
|
|
|
|||
135
mediaapi/storage/storage_test.go
Normal file
135
mediaapi/storage/storage_test.go
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
package storage_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewSyncServerDatasource returned %s", err)
|
||||
}
|
||||
return db, close
|
||||
}
|
||||
func TestMediaRepository(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
ctx := context.Background()
|
||||
t.Run("can insert media & query media", func(t *testing.T) {
|
||||
metadata := &types.MediaMetadata{
|
||||
MediaID: "testing",
|
||||
Origin: "localhost",
|
||||
ContentType: "image/png",
|
||||
FileSizeBytes: 10,
|
||||
UploadName: "upload test",
|
||||
Base64Hash: "dGVzdGluZw==",
|
||||
UserID: "@alice:localhost",
|
||||
}
|
||||
if err := db.StoreMediaMetadata(ctx, metadata); err != nil {
|
||||
t.Fatalf("unable to store media metadata: %v", err)
|
||||
}
|
||||
// query by media id
|
||||
gotMetadata, err := db.GetMediaMetadata(ctx, metadata.MediaID, metadata.Origin)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to query media metadata: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(metadata, gotMetadata) {
|
||||
t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata)
|
||||
}
|
||||
// query by media hash
|
||||
gotMetadata, err = db.GetMediaMetadataByHash(ctx, metadata.Base64Hash, metadata.Origin)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to query media metadata by hash: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(metadata, gotMetadata) {
|
||||
t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestThumbnailsStorage(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
ctx := context.Background()
|
||||
t.Run("can insert thumbnails & query media", func(t *testing.T) {
|
||||
thumbnails := []*types.ThumbnailMetadata{
|
||||
{
|
||||
MediaMetadata: &types.MediaMetadata{
|
||||
MediaID: "testing",
|
||||
Origin: "localhost",
|
||||
ContentType: "image/png",
|
||||
FileSizeBytes: 6,
|
||||
},
|
||||
ThumbnailSize: types.ThumbnailSize{
|
||||
Width: 5,
|
||||
Height: 5,
|
||||
ResizeMethod: types.Crop,
|
||||
},
|
||||
},
|
||||
{
|
||||
MediaMetadata: &types.MediaMetadata{
|
||||
MediaID: "testing",
|
||||
Origin: "localhost",
|
||||
ContentType: "image/png",
|
||||
FileSizeBytes: 7,
|
||||
},
|
||||
ThumbnailSize: types.ThumbnailSize{
|
||||
Width: 1,
|
||||
Height: 1,
|
||||
ResizeMethod: types.Scale,
|
||||
},
|
||||
},
|
||||
}
|
||||
for i := range thumbnails {
|
||||
if err := db.StoreThumbnail(ctx, thumbnails[i]); err != nil {
|
||||
t.Fatalf("unable to store thumbnail metadata: %v", err)
|
||||
}
|
||||
}
|
||||
// query by single thumbnail
|
||||
gotMetadata, err := db.GetThumbnail(ctx,
|
||||
thumbnails[0].MediaMetadata.MediaID,
|
||||
thumbnails[0].MediaMetadata.Origin,
|
||||
thumbnails[0].ThumbnailSize.Width, thumbnails[0].ThumbnailSize.Height,
|
||||
thumbnails[0].ThumbnailSize.ResizeMethod,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to query thumbnail metadata: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata) {
|
||||
t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata)
|
||||
}
|
||||
if !reflect.DeepEqual(thumbnails[0].ThumbnailSize, gotMetadata.ThumbnailSize) {
|
||||
t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata)
|
||||
}
|
||||
// query by all thumbnails
|
||||
gotMediadatas, err := db.GetThumbnails(ctx, thumbnails[0].MediaMetadata.MediaID, thumbnails[0].MediaMetadata.Origin)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to query media metadata by hash: %v", err)
|
||||
}
|
||||
if len(gotMediadatas) != len(thumbnails) {
|
||||
t.Fatalf("expected %d stored thumbnail metadata, got %d", len(thumbnails), len(gotMediadatas))
|
||||
}
|
||||
for i := range gotMediadatas {
|
||||
if !reflect.DeepEqual(thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata) {
|
||||
t.Fatalf("expected metadata %+v, got %v", thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata)
|
||||
}
|
||||
if !reflect.DeepEqual(thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize) {
|
||||
t.Fatalf("expected metadata %+v, got %v", thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
@ -22,10 +22,10 @@ import (
|
|||
)
|
||||
|
||||
// Open opens a postgres database.
|
||||
func Open(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||
func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.Open(dbProperties)
|
||||
return sqlite3.NewDatabase(dbProperties)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return nil, fmt.Errorf("can't use Postgres implementation")
|
||||
default:
|
||||
|
|
|
|||
46
mediaapi/storage/tables/interface.go
Normal file
46
mediaapi/storage/tables/interface.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type Thumbnails interface {
|
||||
InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error
|
||||
SelectThumbnail(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
width, height int,
|
||||
resizeMethod string,
|
||||
) (*types.ThumbnailMetadata, error)
|
||||
SelectThumbnails(
|
||||
ctx context.Context, txn *sql.Tx, mediaID types.MediaID,
|
||||
mediaOrigin gomatrixserverlib.ServerName,
|
||||
) ([]*types.ThumbnailMetadata, error)
|
||||
}
|
||||
|
||||
type MediaRepository interface {
|
||||
InsertMedia(ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata) error
|
||||
SelectMedia(ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
|
||||
SelectMediaByHash(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error)
|
||||
}
|
||||
|
|
@ -45,16 +45,13 @@ type RequestMethod string
|
|||
// MatrixUserID is a Matrix user ID string in the form @user:domain e.g. @alice:matrix.org
|
||||
type MatrixUserID string
|
||||
|
||||
// UnixMs is the milliseconds since the Unix epoch
|
||||
type UnixMs int64
|
||||
|
||||
// MediaMetadata is metadata associated with a media file
|
||||
type MediaMetadata struct {
|
||||
MediaID MediaID
|
||||
Origin gomatrixserverlib.ServerName
|
||||
ContentType ContentType
|
||||
FileSizeBytes FileSizeBytes
|
||||
CreationTimestamp UnixMs
|
||||
CreationTimestamp gomatrixserverlib.Timestamp
|
||||
UploadName Filename
|
||||
Base64Hash Base64Hash
|
||||
UserID MatrixUserID
|
||||
|
|
|
|||
|
|
@ -167,6 +167,7 @@ func (r *Inputer) startWorkerForRoom(roomID string) {
|
|||
// will look to see if we have a worker for that room which has its
|
||||
// own consumer. If we don't, we'll start one.
|
||||
func (r *Inputer) Start() error {
|
||||
prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration)
|
||||
_, err := r.JetStream.Subscribe(
|
||||
"", // This is blank because we specified it in BindStream.
|
||||
func(m *nats.Msg) {
|
||||
|
|
@ -421,10 +422,6 @@ func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) er
|
|||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(roomserverInputBackpressure)
|
||||
}
|
||||
|
||||
var roomserverInputBackpressure = prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: "dendrite",
|
||||
|
|
|
|||
|
|
@ -37,10 +37,6 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(processRoomEventDuration)
|
||||
}
|
||||
|
||||
// TODO: Does this value make sense?
|
||||
const MaximumMissingProcessingTime = time.Minute * 2
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package input_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -12,30 +11,22 @@ import (
|
|||
"github.com/matrix-org/dendrite/roomserver/internal/input"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
func psqlConnectionString() config.DataSource {
|
||||
user := os.Getenv("POSTGRES_USER")
|
||||
if user == "" {
|
||||
user = "dendrite"
|
||||
}
|
||||
dbName := os.Getenv("POSTGRES_DB")
|
||||
if dbName == "" {
|
||||
dbName = "dendrite"
|
||||
}
|
||||
connStr := fmt.Sprintf(
|
||||
"user=%s dbname=%s sslmode=disable", user, dbName,
|
||||
)
|
||||
password := os.Getenv("POSTGRES_PASSWORD")
|
||||
if password != "" {
|
||||
connStr += fmt.Sprintf(" password=%s", password)
|
||||
}
|
||||
host := os.Getenv("POSTGRES_HOST")
|
||||
if host != "" {
|
||||
connStr += fmt.Sprintf(" host=%s", host)
|
||||
}
|
||||
return config.DataSource(connStr)
|
||||
var js nats.JetStreamContext
|
||||
var jc *nats.Conn
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
var pc *process.ProcessContext
|
||||
pc, js, jc = jetstream.PrepareForTests()
|
||||
code := m.Run()
|
||||
pc.ShutdownDendrite()
|
||||
pc.WaitForComponentsToFinish()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestSingleTransactionOnInput(t *testing.T) {
|
||||
|
|
@ -63,7 +54,7 @@ func TestSingleTransactionOnInput(t *testing.T) {
|
|||
}
|
||||
db, err := storage.Open(
|
||||
&config.DatabaseOptions{
|
||||
ConnectionString: psqlConnectionString(),
|
||||
ConnectionString: "",
|
||||
MaxOpenConnections: 1,
|
||||
MaxIdleConnections: 1,
|
||||
},
|
||||
|
|
@ -74,7 +65,9 @@ func TestSingleTransactionOnInput(t *testing.T) {
|
|||
t.SkipNow()
|
||||
}
|
||||
inputter := &input.Inputer{
|
||||
DB: db,
|
||||
DB: db,
|
||||
JetStream: js,
|
||||
NATSClient: jc,
|
||||
}
|
||||
res := &api.InputRoomEventsResponse{}
|
||||
inputter.InputRoomEvents(
|
||||
|
|
|
|||
|
|
@ -13,12 +13,22 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
|
||||
natsserver "github.com/nats-io/nats-server/v2/server"
|
||||
"github.com/nats-io/nats.go"
|
||||
natsclient "github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
var natsServer *natsserver.Server
|
||||
var natsServerMutex sync.Mutex
|
||||
|
||||
func PrepareForTests() (*process.ProcessContext, nats.JetStreamContext, *nats.Conn) {
|
||||
cfg := &config.Dendrite{}
|
||||
cfg.Defaults(true)
|
||||
cfg.Global.JetStream.InMemory = true
|
||||
pc := process.NewProcessContext()
|
||||
js, jc := Prepare(pc, &cfg.Global.JetStream)
|
||||
return pc, js, jc
|
||||
}
|
||||
|
||||
func Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient.JetStreamContext, *natsclient.Conn) {
|
||||
// check if we need an in-process NATS Server
|
||||
if len(cfg.Addresses) != 0 {
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ import (
|
|||
type Notifier struct {
|
||||
lock *sync.RWMutex
|
||||
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
|
||||
roomIDToJoinedUsers map[string]userIDSet
|
||||
roomIDToJoinedUsers map[string]*userIDSet
|
||||
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
|
||||
roomIDToPeekingDevices map[string]peekingDeviceSet
|
||||
// The latest sync position
|
||||
|
|
@ -54,7 +54,7 @@ type Notifier struct {
|
|||
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{
|
||||
roomIDToJoinedUsers: make(map[string]userIDSet),
|
||||
roomIDToJoinedUsers: make(map[string]*userIDSet),
|
||||
roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
|
||||
userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
|
||||
lock: &sync.RWMutex{},
|
||||
|
|
@ -262,7 +262,7 @@ func (n *Notifier) SharedUsers(userID string) []string {
|
|||
func (n *Notifier) _sharedUsers(userID string) []string {
|
||||
n._sharedUserMap[userID] = struct{}{}
|
||||
for roomID, users := range n.roomIDToJoinedUsers {
|
||||
if _, ok := users[userID]; !ok {
|
||||
if ok := users.isIn(userID); !ok {
|
||||
continue
|
||||
}
|
||||
for _, userID := range n._joinedUsers(roomID) {
|
||||
|
|
@ -282,8 +282,11 @@ func (n *Notifier) IsSharedUser(userA, userB string) bool {
|
|||
defer n.lock.RUnlock()
|
||||
var okA, okB bool
|
||||
for _, users := range n.roomIDToJoinedUsers {
|
||||
_, okA = users[userA]
|
||||
_, okB = users[userB]
|
||||
okA = users.isIn(userA)
|
||||
if !okA {
|
||||
continue
|
||||
}
|
||||
okB = users.isIn(userB)
|
||||
if okA && okB {
|
||||
return true
|
||||
}
|
||||
|
|
@ -345,11 +348,12 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
|
|||
// This is just the bulk form of addJoinedUser
|
||||
for roomID, userIDs := range roomIDToUserIDs {
|
||||
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
|
||||
n.roomIDToJoinedUsers[roomID] = make(userIDSet, len(userIDs))
|
||||
n.roomIDToJoinedUsers[roomID] = newUserIDSet(len(userIDs))
|
||||
}
|
||||
for _, userID := range userIDs {
|
||||
n.roomIDToJoinedUsers[roomID].add(userID)
|
||||
}
|
||||
n.roomIDToJoinedUsers[roomID].precompute()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -440,16 +444,18 @@ func (n *Notifier) _fetchUserStreams(userID string) []*UserDeviceStream {
|
|||
|
||||
func (n *Notifier) _addJoinedUser(roomID, userID string) {
|
||||
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
|
||||
n.roomIDToJoinedUsers[roomID] = make(userIDSet)
|
||||
n.roomIDToJoinedUsers[roomID] = newUserIDSet(8)
|
||||
}
|
||||
n.roomIDToJoinedUsers[roomID].add(userID)
|
||||
n.roomIDToJoinedUsers[roomID].precompute()
|
||||
}
|
||||
|
||||
func (n *Notifier) _removeJoinedUser(roomID, userID string) {
|
||||
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
|
||||
n.roomIDToJoinedUsers[roomID] = make(userIDSet)
|
||||
n.roomIDToJoinedUsers[roomID] = newUserIDSet(8)
|
||||
}
|
||||
n.roomIDToJoinedUsers[roomID].remove(userID)
|
||||
n.roomIDToJoinedUsers[roomID].precompute()
|
||||
}
|
||||
|
||||
func (n *Notifier) JoinedUsers(roomID string) (userIDs []string) {
|
||||
|
|
@ -521,19 +527,52 @@ func (n *Notifier) _removeEmptyUserStreams() {
|
|||
}
|
||||
|
||||
// A string set, mainly existing for improving clarity of structs in this file.
|
||||
type userIDSet map[string]struct{}
|
||||
|
||||
func (s userIDSet) add(str string) {
|
||||
s[str] = struct{}{}
|
||||
type userIDSet struct {
|
||||
sync.Mutex
|
||||
set map[string]struct{}
|
||||
precomputed []string
|
||||
}
|
||||
|
||||
func (s userIDSet) remove(str string) {
|
||||
delete(s, str)
|
||||
func newUserIDSet(cap int) *userIDSet {
|
||||
return &userIDSet{
|
||||
set: make(map[string]struct{}, cap),
|
||||
precomputed: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func (s userIDSet) values() (vals []string) {
|
||||
vals = make([]string, 0, len(s))
|
||||
for str := range s {
|
||||
func (s *userIDSet) add(str string) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.set[str] = struct{}{}
|
||||
s.precomputed = s.precomputed[:0] // invalidate cache
|
||||
}
|
||||
|
||||
func (s *userIDSet) remove(str string) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
delete(s.set, str)
|
||||
s.precomputed = s.precomputed[:0] // invalidate cache
|
||||
}
|
||||
|
||||
func (s *userIDSet) precompute() {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.precomputed = s.values()
|
||||
}
|
||||
|
||||
func (s *userIDSet) isIn(str string) bool {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
_, ok := s.set[str]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *userIDSet) values() (vals []string) {
|
||||
if len(s.precomputed) > 0 {
|
||||
return s.precomputed // only return if not invalidated
|
||||
}
|
||||
vals = make([]string, 0, len(s.set))
|
||||
for str := range s.set {
|
||||
vals = append(vals, str)
|
||||
}
|
||||
return
|
||||
|
|
|
|||
|
|
@ -165,9 +165,9 @@ func TestCorrectStreamWakeup(t *testing.T) {
|
|||
|
||||
go func() {
|
||||
select {
|
||||
case <-streamone.signalChannel:
|
||||
case <-streamone.ch():
|
||||
awoken <- "one"
|
||||
case <-streamtwo.signalChannel:
|
||||
case <-streamtwo.ch():
|
||||
awoken <- "two"
|
||||
}
|
||||
}()
|
||||
|
|
|
|||
|
|
@ -118,6 +118,12 @@ func (s *UserDeviceStream) TimeOfLastNonEmpty() time.Time {
|
|||
return s.timeOfLastChannel
|
||||
}
|
||||
|
||||
func (s *UserDeviceStream) ch() <-chan struct{} {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
return s.signalChannel
|
||||
}
|
||||
|
||||
// GetSyncPosition returns last sync position which the UserStream was
|
||||
// notified about
|
||||
func (s *UserDeviceStreamListener) GetSyncPosition() types.StreamingToken {
|
||||
|
|
|
|||
|
|
@ -60,7 +60,9 @@ func Context(
|
|||
Headers: nil,
|
||||
}
|
||||
}
|
||||
filter.Rooms = append(filter.Rooms, roomID)
|
||||
if filter.Rooms != nil {
|
||||
*filter.Rooms = append(*filter.Rooms, roomID)
|
||||
}
|
||||
|
||||
ctx := req.Context()
|
||||
membershipRes := roomserver.QueryMembershipForUserResponse{}
|
||||
|
|
|
|||
|
|
@ -262,12 +262,8 @@ func (r *messagesReq) retrieveEvents() (
|
|||
clientEvents []gomatrixserverlib.ClientEvent, start,
|
||||
end types.TopologyToken, err error,
|
||||
) {
|
||||
eventFilter := r.filter
|
||||
|
||||
// Retrieve the events from the local database.
|
||||
streamEvents, err := r.db.GetEventsInTopologicalRange(
|
||||
r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
|
||||
)
|
||||
streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("GetEventsInRange: %w", err)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -104,8 +104,8 @@ type Database interface {
|
|||
// DeletePeek deletes all peeks for a given room by a given user
|
||||
// Returns an error if there was a problem communicating with the database.
|
||||
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
|
||||
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
|
||||
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
|
||||
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
|
||||
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
|
||||
// EventPositionInTopology returns the depth and stream position of the given event.
|
||||
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
|
||||
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
|
||||
|
|
|
|||
|
|
@ -47,14 +47,10 @@ const selectBackwardExtremitiesForRoomSQL = "" +
|
|||
const deleteBackwardExtremitySQL = "" +
|
||||
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
|
||||
|
||||
const deleteBackwardExtremitiesForRoomSQL = "" +
|
||||
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
|
||||
|
||||
type backwardExtremitiesStatements struct {
|
||||
insertBackwardExtremityStmt *sql.Stmt
|
||||
selectBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||
deleteBackwardExtremityStmt *sql.Stmt
|
||||
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
|
||||
|
|
@ -72,9 +68,6 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti
|
|||
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
|
@ -113,10 +106,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
|||
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -233,9 +233,10 @@ func (s *currentRoomStateStatements) SelectCurrentState(
|
|||
excludeEventIDs []string,
|
||||
) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt)
|
||||
senders, notSenders := getSendersStateFilterFilter(stateFilter)
|
||||
rows, err := stmt.QueryContext(ctx, roomID,
|
||||
pq.StringArray(stateFilter.Senders),
|
||||
pq.StringArray(stateFilter.NotSenders),
|
||||
pq.StringArray(senders),
|
||||
pq.StringArray(notSenders),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
|
||||
stateFilter.ContainsURL,
|
||||
|
|
|
|||
|
|
@ -16,21 +16,45 @@ package postgres
|
|||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// filterConvertWildcardToSQL converts wildcards as defined in
|
||||
// https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter
|
||||
// to SQL wildcards that can be used with LIKE()
|
||||
func filterConvertTypeWildcardToSQL(values []string) []string {
|
||||
func filterConvertTypeWildcardToSQL(values *[]string) []string {
|
||||
if values == nil {
|
||||
// Return nil instead of []string{} so IS NULL can work correctly when
|
||||
// the return value is passed into SQL queries
|
||||
return nil
|
||||
}
|
||||
|
||||
ret := make([]string, len(values))
|
||||
for i := range values {
|
||||
ret[i] = strings.Replace(values[i], "*", "%", -1)
|
||||
v := *values
|
||||
ret := make([]string, len(v))
|
||||
for i := range v {
|
||||
ret[i] = strings.Replace(v[i], "*", "%", -1)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// TODO: Replace when Dendrite uses Go 1.18
|
||||
func getSendersRoomEventFilter(filter *gomatrixserverlib.RoomEventFilter) (senders []string, notSenders []string) {
|
||||
if filter.Senders != nil {
|
||||
senders = *filter.Senders
|
||||
}
|
||||
if filter.NotSenders != nil {
|
||||
notSenders = *filter.NotSenders
|
||||
}
|
||||
return senders, notSenders
|
||||
}
|
||||
|
||||
func getSendersStateFilterFilter(filter *gomatrixserverlib.StateFilter) (senders []string, notSenders []string) {
|
||||
if filter.Senders != nil {
|
||||
senders = *filter.Senders
|
||||
}
|
||||
if filter.NotSenders != nil {
|
||||
notSenders = *filter.NotSenders
|
||||
}
|
||||
return senders, notSenders
|
||||
}
|
||||
|
|
|
|||
|
|
@ -56,12 +56,6 @@ const upsertMembershipSQL = "" +
|
|||
" ON CONFLICT ON CONSTRAINT syncapi_memberships_unique" +
|
||||
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
|
||||
|
||||
const selectMembershipSQL = "" +
|
||||
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
|
||||
" WHERE room_id = $1 AND user_id = $2 AND membership = ANY($3)" +
|
||||
" ORDER BY stream_pos DESC" +
|
||||
" LIMIT 1"
|
||||
|
||||
const selectMembershipCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM (" +
|
||||
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
|
||||
|
|
@ -69,7 +63,6 @@ const selectMembershipCountSQL = "" +
|
|||
|
||||
type membershipsStatements struct {
|
||||
upsertMembershipStmt *sql.Stmt
|
||||
selectMembershipStmt *sql.Stmt
|
||||
selectMembershipCountStmt *sql.Stmt
|
||||
}
|
||||
|
||||
|
|
@ -82,9 +75,6 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
|
|||
if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -111,14 +101,6 @@ func (s *membershipsStatements) UpsertMembership(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *membershipsStatements) SelectMembership(
|
||||
ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
|
||||
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt)
|
||||
err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipsStatements) SelectMembershipCount(
|
||||
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
|
||||
) (count int, err error) {
|
||||
|
|
|
|||
|
|
@ -81,6 +81,15 @@ const insertEventSQL = "" +
|
|||
const selectEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)"
|
||||
|
||||
const selectEventsWithFilterSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" +
|
||||
" AND ( $2::text[] IS NULL OR sender = ANY($2) )" +
|
||||
" AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" +
|
||||
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
|
||||
" AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" +
|
||||
" AND ( $6::bool IS NULL OR contains_url = $6 )" +
|
||||
" LIMIT $7"
|
||||
|
||||
const selectRecentEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
|
||||
|
|
@ -153,6 +162,7 @@ const selectContextAfterEventSQL = "" +
|
|||
type outputRoomEventsStatements struct {
|
||||
insertEventStmt *sql.Stmt
|
||||
selectEventsStmt *sql.Stmt
|
||||
selectEventsWitFilterStmt *sql.Stmt
|
||||
selectMaxEventIDStmt *sql.Stmt
|
||||
selectRecentEventsStmt *sql.Stmt
|
||||
selectRecentEventsForSyncStmt *sql.Stmt
|
||||
|
|
@ -174,6 +184,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
|
|||
return s, sqlutil.StatementList{
|
||||
{&s.insertEventStmt, insertEventSQL},
|
||||
{&s.selectEventsStmt, selectEventsSQL},
|
||||
{&s.selectEventsWitFilterStmt, selectEventsWithFilterSQL},
|
||||
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
|
||||
{&s.selectRecentEventsStmt, selectRecentEventsSQL},
|
||||
{&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL},
|
||||
|
|
@ -204,11 +215,11 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
|
|||
stateFilter *gomatrixserverlib.StateFilter, roomIDs []string,
|
||||
) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)
|
||||
|
||||
senders, notSenders := getSendersStateFilterFilter(stateFilter)
|
||||
rows, err := stmt.QueryContext(
|
||||
ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
|
||||
pq.StringArray(stateFilter.Senders),
|
||||
pq.StringArray(stateFilter.NotSenders),
|
||||
pq.StringArray(senders),
|
||||
pq.StringArray(notSenders),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
|
||||
stateFilter.ContainsURL,
|
||||
|
|
@ -310,7 +321,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
|
|||
// Parse content as JSON and search for an "url" key
|
||||
containsURL := false
|
||||
var content map[string]interface{}
|
||||
if json.Unmarshal(event.Content(), &content) != nil {
|
||||
if json.Unmarshal(event.Content(), &content) == nil {
|
||||
// Set containsURL to true if url is present
|
||||
_, containsURL = content["url"]
|
||||
}
|
||||
|
|
@ -353,10 +364,11 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
|
|||
} else {
|
||||
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
|
||||
}
|
||||
senders, notSenders := getSendersRoomEventFilter(eventFilter)
|
||||
rows, err := stmt.QueryContext(
|
||||
ctx, roomID, r.Low(), r.High(),
|
||||
pq.StringArray(eventFilter.Senders),
|
||||
pq.StringArray(eventFilter.NotSenders),
|
||||
pq.StringArray(senders),
|
||||
pq.StringArray(notSenders),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
|
||||
eventFilter.Limit+1,
|
||||
|
|
@ -398,11 +410,12 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
|||
ctx context.Context, txn *sql.Tx,
|
||||
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
|
||||
) ([]types.StreamEvent, error) {
|
||||
senders, notSenders := getSendersRoomEventFilter(eventFilter)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt)
|
||||
rows, err := stmt.QueryContext(
|
||||
ctx, roomID, r.Low(), r.High(),
|
||||
pq.StringArray(eventFilter.Senders),
|
||||
pq.StringArray(eventFilter.NotSenders),
|
||||
pq.StringArray(senders),
|
||||
pq.StringArray(notSenders),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
|
||||
eventFilter.Limit,
|
||||
|
|
@ -427,15 +440,52 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
|||
// selectEvents returns the events for the given event IDs. If an event is
|
||||
// missing from the database, it will be omitted.
|
||||
func (s *outputRoomEventsStatements) SelectEvents(
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
|
||||
) ([]types.StreamEvent, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||
var (
|
||||
stmt *sql.Stmt
|
||||
rows *sql.Rows
|
||||
err error
|
||||
)
|
||||
if filter == nil {
|
||||
stmt = sqlutil.TxStmt(txn, s.selectEventsStmt)
|
||||
rows, err = stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||
} else {
|
||||
senders, notSenders := getSendersRoomEventFilter(filter)
|
||||
stmt = sqlutil.TxStmt(txn, s.selectEventsWitFilterStmt)
|
||||
rows, err = stmt.QueryContext(ctx,
|
||||
pq.StringArray(eventIDs),
|
||||
pq.StringArray(senders),
|
||||
pq.StringArray(notSenders),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
|
||||
filter.ContainsURL,
|
||||
filter.Limit,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
|
||||
return rowsToStreamEvents(rows)
|
||||
streamEvents, err := rowsToStreamEvents(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if preserveOrder {
|
||||
eventMap := make(map[string]types.StreamEvent)
|
||||
for _, ev := range streamEvents {
|
||||
eventMap[ev.EventID()] = ev
|
||||
}
|
||||
var returnEvents []types.StreamEvent
|
||||
for _, eventID := range eventIDs {
|
||||
ev, ok := eventMap[eventID]
|
||||
if ok {
|
||||
returnEvents = append(returnEvents, ev)
|
||||
}
|
||||
}
|
||||
return returnEvents, nil
|
||||
}
|
||||
return streamEvents, nil
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
|
||||
|
|
@ -462,10 +512,11 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn
|
|||
func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
|
||||
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
|
||||
) (evts []*gomatrixserverlib.HeaderedEvent, err error) {
|
||||
senders, notSenders := getSendersRoomEventFilter(filter)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext(
|
||||
ctx, roomID, id, filter.Limit,
|
||||
pq.StringArray(filter.Senders),
|
||||
pq.StringArray(filter.NotSenders),
|
||||
pq.StringArray(senders),
|
||||
pq.StringArray(notSenders),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
|
||||
)
|
||||
|
|
@ -494,10 +545,11 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
|
|||
func (s *outputRoomEventsStatements) SelectContextAfterEvent(
|
||||
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
|
||||
) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) {
|
||||
senders, notSenders := getSendersRoomEventFilter(filter)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext(
|
||||
ctx, roomID, id, filter.Limit,
|
||||
pq.StringArray(filter.Senders),
|
||||
pq.StringArray(filter.NotSenders),
|
||||
pq.StringArray(senders),
|
||||
pq.StringArray(notSenders),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -73,9 +73,6 @@ const selectMaxPositionInTopologySQL = "" +
|
|||
"SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" +
|
||||
") ORDER BY stream_position DESC LIMIT 1"
|
||||
|
||||
const deleteTopologyForRoomSQL = "" +
|
||||
"DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
|
||||
|
||||
const selectStreamToTopologicalPositionAscSQL = "" +
|
||||
"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;"
|
||||
|
||||
|
|
@ -88,7 +85,6 @@ type outputRoomEventsTopologyStatements struct {
|
|||
selectEventIDsInRangeDESCStmt *sql.Stmt
|
||||
selectPositionInTopologyStmt *sql.Stmt
|
||||
selectMaxPositionInTopologyStmt *sql.Stmt
|
||||
deleteTopologyForRoomStmt *sql.Stmt
|
||||
selectStreamToTopologicalPositionAscStmt *sql.Stmt
|
||||
selectStreamToTopologicalPositionDescStmt *sql.Stmt
|
||||
}
|
||||
|
|
@ -114,9 +110,6 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) {
|
|||
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -148,9 +141,9 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
|||
// is requested or not.
|
||||
var stmt *sql.Stmt
|
||||
if chronologicalOrder {
|
||||
stmt = s.selectEventIDsInRangeASCStmt
|
||||
stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt)
|
||||
} else {
|
||||
stmt = s.selectEventIDsInRangeDESCStmt
|
||||
stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt)
|
||||
}
|
||||
|
||||
// Query the event IDs.
|
||||
|
|
@ -203,10 +196,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
|
|||
err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre
|
|||
// Returns an error if there was a problem talking with the database.
|
||||
// Does not include any transaction IDs in the returned events.
|
||||
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs)
|
||||
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
|
|||
|
||||
// Check if we have all of the event's previous events. If an event is
|
||||
// missing, add it to the room's backward extremities.
|
||||
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs())
|
||||
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -429,7 +429,8 @@ func (d *Database) updateRoomState(
|
|||
func (d *Database) GetEventsInTopologicalRange(
|
||||
ctx context.Context,
|
||||
from, to *types.TopologyToken,
|
||||
roomID string, limit int,
|
||||
roomID string,
|
||||
filter *gomatrixserverlib.RoomEventFilter,
|
||||
backwardOrdering bool,
|
||||
) (events []types.StreamEvent, err error) {
|
||||
var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
|
||||
|
|
@ -450,14 +451,14 @@ func (d *Database) GetEventsInTopologicalRange(
|
|||
// Select the event IDs from the defined range.
|
||||
var eIDs []string
|
||||
eIDs, err = d.Topology.SelectEventIDsInRange(
|
||||
ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, limit, !backwardOrdering,
|
||||
ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Retrieve the events' contents using their IDs.
|
||||
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs)
|
||||
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, filter, true)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -619,7 +620,7 @@ func (d *Database) fetchMissingStateEvents(
|
|||
) ([]types.StreamEvent, error) {
|
||||
// Fetch from the events table first so we pick up the stream ID for the
|
||||
// event.
|
||||
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs)
|
||||
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -41,23 +41,23 @@ const insertAccountDataSQL = "" +
|
|||
" ON CONFLICT (user_id, room_id, type) DO UPDATE" +
|
||||
" SET id = $5"
|
||||
|
||||
// further parameters are added by prepareWithFilters
|
||||
const selectAccountDataInRangeSQL = "" +
|
||||
"SELECT room_id, type FROM syncapi_account_data_type" +
|
||||
" WHERE user_id = $1 AND id > $2 AND id <= $3" +
|
||||
" ORDER BY id ASC"
|
||||
" WHERE user_id = $1 AND id > $2 AND id <= $3"
|
||||
|
||||
const selectMaxAccountDataIDSQL = "" +
|
||||
"SELECT MAX(id) FROM syncapi_account_data_type"
|
||||
|
||||
type accountDataStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
insertAccountDataStmt *sql.Stmt
|
||||
selectMaxAccountDataIDStmt *sql.Stmt
|
||||
selectAccountDataInRangeStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
|
||||
func NewSqliteAccountDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.AccountData, error) {
|
||||
s := &accountDataStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
|
|
@ -94,18 +94,24 @@ func (s *accountDataStatements) SelectAccountDataInRange(
|
|||
ctx context.Context,
|
||||
userID string,
|
||||
r types.Range,
|
||||
accountDataFilterPart *gomatrixserverlib.EventFilter,
|
||||
filter *gomatrixserverlib.EventFilter,
|
||||
) (data map[string][]string, err error) {
|
||||
data = make(map[string][]string)
|
||||
stmt, params, err := prepareWithFilters(
|
||||
s.db, nil, selectAccountDataInRangeSQL,
|
||||
[]interface{}{
|
||||
userID, r.Low(), r.High(),
|
||||
},
|
||||
filter.Senders, filter.NotSenders,
|
||||
filter.Types, filter.NotTypes,
|
||||
[]string{}, nil, filter.Limit, FilterOrderAsc)
|
||||
|
||||
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High())
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
|
||||
|
||||
var entries int
|
||||
|
||||
for rows.Next() {
|
||||
var dataType string
|
||||
var roomID string
|
||||
|
|
@ -114,31 +120,11 @@ func (s *accountDataStatements) SelectAccountDataInRange(
|
|||
return
|
||||
}
|
||||
|
||||
// check if we should add this by looking at the filter.
|
||||
// It would be nice if we could do this in SQL-land, but the mix of variadic
|
||||
// and positional parameters makes the query annoyingly hard to do, it's easier
|
||||
// and clearer to do it in Go-land. If there are no filters for [not]types then
|
||||
// this gets skipped.
|
||||
for _, includeType := range accountDataFilterPart.Types {
|
||||
if includeType != dataType { // TODO: wildcard support
|
||||
continue
|
||||
}
|
||||
}
|
||||
for _, excludeType := range accountDataFilterPart.NotTypes {
|
||||
if excludeType == dataType { // TODO: wildcard support
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if len(data[roomID]) > 0 {
|
||||
data[roomID] = append(data[roomID], dataType)
|
||||
} else {
|
||||
data[roomID] = []string{dataType}
|
||||
}
|
||||
entries++
|
||||
if entries >= accountDataFilterPart.Limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
|
|
|
|||
|
|
@ -47,15 +47,11 @@ const selectBackwardExtremitiesForRoomSQL = "" +
|
|||
const deleteBackwardExtremitySQL = "" +
|
||||
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
|
||||
|
||||
const deleteBackwardExtremitiesForRoomSQL = "" +
|
||||
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
|
||||
|
||||
type backwardExtremitiesStatements struct {
|
||||
db *sql.DB
|
||||
insertBackwardExtremityStmt *sql.Stmt
|
||||
selectBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||
deleteBackwardExtremityStmt *sql.Stmt
|
||||
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
|
||||
|
|
@ -75,9 +71,6 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
|
|||
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
|
@ -116,10 +109,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
|||
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ const selectEventsWithEventIDsSQL = "" +
|
|||
|
||||
type currentRoomStateStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
upsertRoomStateStmt *sql.Stmt
|
||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||
deleteRoomStateForRoomStmt *sql.Stmt
|
||||
|
|
@ -100,7 +100,7 @@ type currentRoomStateStatements struct {
|
|||
selectStateEventStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
|
||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
|
||||
s := ¤tRoomStateStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
|
|
@ -220,7 +220,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
|
|||
},
|
||||
stateFilter.Senders, stateFilter.NotSenders,
|
||||
stateFilter.Types, stateFilter.NotTypes,
|
||||
excludeEventIDs, stateFilter.Limit, FilterOrderNone,
|
||||
excludeEventIDs, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderNone,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
|
|
|
|||
|
|
@ -25,34 +25,53 @@ const (
|
|||
// parts.
|
||||
func prepareWithFilters(
|
||||
db *sql.DB, txn *sql.Tx, query string, params []interface{},
|
||||
senders, notsenders, types, nottypes []string, excludeEventIDs []string,
|
||||
limit int, order FilterOrder,
|
||||
senders, notsenders, types, nottypes *[]string, excludeEventIDs []string,
|
||||
containsURL *bool, limit int, order FilterOrder,
|
||||
) (*sql.Stmt, []interface{}, error) {
|
||||
offset := len(params)
|
||||
if count := len(senders); count > 0 {
|
||||
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range senders {
|
||||
params, offset = append(params, v), offset+1
|
||||
if senders != nil {
|
||||
if count := len(*senders); count > 0 {
|
||||
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range *senders {
|
||||
params, offset = append(params, v), offset+1
|
||||
}
|
||||
} else {
|
||||
query += ` AND sender = ""`
|
||||
}
|
||||
}
|
||||
if count := len(notsenders); count > 0 {
|
||||
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range notsenders {
|
||||
params, offset = append(params, v), offset+1
|
||||
if notsenders != nil {
|
||||
if count := len(*notsenders); count > 0 {
|
||||
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range *notsenders {
|
||||
params, offset = append(params, v), offset+1
|
||||
}
|
||||
} else {
|
||||
query += ` AND sender NOT = ""`
|
||||
}
|
||||
}
|
||||
if count := len(types); count > 0 {
|
||||
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range types {
|
||||
params, offset = append(params, v), offset+1
|
||||
if types != nil {
|
||||
if count := len(*types); count > 0 {
|
||||
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range *types {
|
||||
params, offset = append(params, v), offset+1
|
||||
}
|
||||
} else {
|
||||
query += ` AND type = ""`
|
||||
}
|
||||
}
|
||||
if count := len(nottypes); count > 0 {
|
||||
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range nottypes {
|
||||
params, offset = append(params, v), offset+1
|
||||
if nottypes != nil {
|
||||
if count := len(*nottypes); count > 0 {
|
||||
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range *nottypes {
|
||||
params, offset = append(params, v), offset+1
|
||||
}
|
||||
} else {
|
||||
query += ` AND type NOT = ""`
|
||||
}
|
||||
}
|
||||
if containsURL != nil {
|
||||
query += fmt.Sprintf(" AND contains_url = %v", *containsURL)
|
||||
}
|
||||
if count := len(excludeEventIDs); count > 0 {
|
||||
query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range excludeEventIDs {
|
||||
|
|
|
|||
|
|
@ -59,14 +59,14 @@ const selectMaxInviteIDSQL = "" +
|
|||
|
||||
type inviteEventsStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
insertInviteEventStmt *sql.Stmt
|
||||
selectInviteEventsInRangeStmt *sql.Stmt
|
||||
deleteInviteEventStmt *sql.Stmt
|
||||
selectMaxInviteIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
|
||||
func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) {
|
||||
s := &inviteEventsStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
|
|
@ -57,12 +56,6 @@ const upsertMembershipSQL = "" +
|
|||
" ON CONFLICT (room_id, user_id, membership)" +
|
||||
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
|
||||
|
||||
const selectMembershipSQL = "" +
|
||||
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
|
||||
" WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" +
|
||||
" ORDER BY stream_pos DESC" +
|
||||
" LIMIT 1"
|
||||
|
||||
const selectMembershipCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM (" +
|
||||
" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +
|
||||
|
|
@ -111,22 +104,6 @@ func (s *membershipsStatements) UpsertMembership(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *membershipsStatements) SelectMembership(
|
||||
ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
|
||||
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
|
||||
params := []interface{}{roomID, userID}
|
||||
for _, membership := range memberships {
|
||||
params = append(params, membership)
|
||||
}
|
||||
orig := strings.Replace(selectMembershipSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
|
||||
stmt, err := s.db.Prepare(orig)
|
||||
if err != nil {
|
||||
return "", 0, 0, err
|
||||
}
|
||||
err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipsStatements) SelectMembershipCount(
|
||||
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
|
||||
) (count int, err error) {
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ const insertEventSQL = "" +
|
|||
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
|
||||
|
||||
const selectEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)"
|
||||
|
||||
const selectRecentEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||
|
|
@ -111,9 +111,8 @@ const selectContextAfterEventSQL = "" +
|
|||
|
||||
type outputRoomEventsStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
insertEventStmt *sql.Stmt
|
||||
selectEventsStmt *sql.Stmt
|
||||
selectMaxEventIDStmt *sql.Stmt
|
||||
updateEventJSONStmt *sql.Stmt
|
||||
deleteEventsForRoomStmt *sql.Stmt
|
||||
|
|
@ -122,7 +121,7 @@ type outputRoomEventsStatements struct {
|
|||
selectContextAfterEventStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
|
||||
func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) {
|
||||
s := &outputRoomEventsStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
|
|
@ -133,7 +132,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
|
|||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertEventStmt, insertEventSQL},
|
||||
{&s.selectEventsStmt, selectEventsSQL},
|
||||
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
|
||||
{&s.updateEventJSONStmt, updateEventJSONSQL},
|
||||
{&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL},
|
||||
|
|
@ -170,7 +168,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
|
|||
s.db, txn, stmtSQL, inputParams,
|
||||
stateFilter.Senders, stateFilter.NotSenders,
|
||||
stateFilter.Types, stateFilter.NotTypes,
|
||||
nil, stateFilter.Limit, FilterOrderAsc,
|
||||
nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
|
|
@ -279,7 +277,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
|
|||
// Parse content as JSON and search for an "url" key
|
||||
containsURL := false
|
||||
var content map[string]interface{}
|
||||
if json.Unmarshal(event.Content(), &content) != nil {
|
||||
if json.Unmarshal(event.Content(), &content) == nil {
|
||||
// Set containsURL to true if url is present
|
||||
_, containsURL = content["url"]
|
||||
}
|
||||
|
|
@ -347,7 +345,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
|
|||
},
|
||||
eventFilter.Senders, eventFilter.NotSenders,
|
||||
eventFilter.Types, eventFilter.NotTypes,
|
||||
nil, eventFilter.Limit+1, FilterOrderDesc,
|
||||
nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
|
|
@ -395,7 +393,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
|||
},
|
||||
eventFilter.Senders, eventFilter.NotSenders,
|
||||
eventFilter.Types, eventFilter.NotTypes,
|
||||
nil, eventFilter.Limit, FilterOrderAsc,
|
||||
nil, eventFilter.ContainsURL, eventFilter.Limit, FilterOrderAsc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
|
|
@ -421,21 +419,50 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
|||
// selectEvents returns the events for the given event IDs. If an event is
|
||||
// missing from the database, it will be omitted.
|
||||
func (s *outputRoomEventsStatements) SelectEvents(
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
|
||||
) ([]types.StreamEvent, error) {
|
||||
var returnEvents []types.StreamEvent
|
||||
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
|
||||
for _, eventID := range eventIDs {
|
||||
rows, err := stmt.QueryContext(ctx, eventID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if streamEvents, err := rowsToStreamEvents(rows); err == nil {
|
||||
returnEvents = append(returnEvents, streamEvents...)
|
||||
}
|
||||
internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
|
||||
iEventIDs := make([]interface{}, len(eventIDs))
|
||||
for i := range eventIDs {
|
||||
iEventIDs[i] = eventIDs[i]
|
||||
}
|
||||
return returnEvents, nil
|
||||
selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
|
||||
|
||||
if filter == nil {
|
||||
filter = &gomatrixserverlib.RoomEventFilter{Limit: 20}
|
||||
}
|
||||
stmt, params, err := prepareWithFilters(
|
||||
s.db, txn, selectSQL, iEventIDs,
|
||||
filter.Senders, filter.NotSenders,
|
||||
filter.Types, filter.NotTypes,
|
||||
nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
|
||||
streamEvents, err := rowsToStreamEvents(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if preserveOrder {
|
||||
var returnEvents []types.StreamEvent
|
||||
eventMap := make(map[string]types.StreamEvent)
|
||||
for _, ev := range streamEvents {
|
||||
eventMap[ev.EventID()] = ev
|
||||
}
|
||||
for _, eventID := range eventIDs {
|
||||
ev, ok := eventMap[eventID]
|
||||
if ok {
|
||||
returnEvents = append(returnEvents, ev)
|
||||
}
|
||||
}
|
||||
return returnEvents, nil
|
||||
}
|
||||
return streamEvents, nil
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
|
||||
|
|
@ -507,7 +534,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
|
|||
},
|
||||
filter.Senders, filter.NotSenders,
|
||||
filter.Types, filter.NotTypes,
|
||||
nil, filter.Limit, FilterOrderDesc,
|
||||
nil, filter.ContainsURL, filter.Limit, FilterOrderDesc,
|
||||
)
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
|
|
@ -543,7 +570,7 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent(
|
|||
},
|
||||
filter.Senders, filter.NotSenders,
|
||||
filter.Types, filter.NotTypes,
|
||||
nil, filter.Limit, FilterOrderAsc,
|
||||
nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
|
||||
)
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
|
|
|
|||
|
|
@ -78,7 +78,6 @@ type outputRoomEventsTopologyStatements struct {
|
|||
selectEventIDsInRangeDESCStmt *sql.Stmt
|
||||
selectPositionInTopologyStmt *sql.Stmt
|
||||
selectMaxPositionInTopologyStmt *sql.Stmt
|
||||
deleteTopologyForRoomStmt *sql.Stmt
|
||||
selectStreamToTopologicalPositionAscStmt *sql.Stmt
|
||||
selectStreamToTopologicalPositionDescStmt *sql.Stmt
|
||||
}
|
||||
|
|
@ -191,10 +190,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
|
|||
err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ const selectMaxPeekIDSQL = "" +
|
|||
|
||||
type peekStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
insertPeekStmt *sql.Stmt
|
||||
deletePeekStmt *sql.Stmt
|
||||
deletePeeksStmt *sql.Stmt
|
||||
|
|
@ -75,7 +75,7 @@ type peekStatements struct {
|
|||
selectMaxPeekIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) {
|
||||
func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) {
|
||||
_, err := db.Exec(peeksSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ const selectPresenceAfter = "" +
|
|||
|
||||
type presenceStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
upsertPresenceStmt *sql.Stmt
|
||||
upsertPresenceFromSyncStmt *sql.Stmt
|
||||
selectPresenceForUsersStmt *sql.Stmt
|
||||
|
|
@ -83,7 +83,7 @@ type presenceStatements struct {
|
|||
selectPresenceAfterStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqlitePresenceTable(db *sql.DB, streamID *streamIDStatements) (*presenceStatements, error) {
|
||||
func NewSqlitePresenceTable(db *sql.DB, streamID *StreamIDStatements) (*presenceStatements, error) {
|
||||
_, err := db.Exec(presenceSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -59,13 +59,13 @@ const selectMaxReceiptIDSQL = "" +
|
|||
|
||||
type receiptStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
upsertReceipt *sql.Stmt
|
||||
selectRoomReceipts *sql.Stmt
|
||||
selectMaxReceiptID *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) {
|
||||
func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) {
|
||||
_, err := db.Exec(receiptsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -32,12 +32,12 @@ const increaseStreamIDStmt = "" +
|
|||
"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" +
|
||||
" RETURNING stream_id"
|
||||
|
||||
type streamIDStatements struct {
|
||||
type StreamIDStatements struct {
|
||||
db *sql.DB
|
||||
increaseStreamIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
||||
func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(streamIDTableSchema)
|
||||
if err != nil {
|
||||
|
|
@ -49,31 +49,31 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
func (s *StreamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
func (s *StreamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
func (s *StreamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ type SyncServerDatasource struct {
|
|||
shared.Database
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
streamID streamIDStatements
|
||||
streamID StreamIDStatements
|
||||
}
|
||||
|
||||
// NewDatabase creates a new sync server database
|
||||
|
|
@ -49,7 +49,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
|
|||
}
|
||||
|
||||
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
|
||||
if err = d.streamID.prepare(d.db); err != nil {
|
||||
if err = d.streamID.Prepare(d.db); err != nil {
|
||||
return err
|
||||
}
|
||||
accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)
|
||||
|
|
|
|||
|
|
@ -1,121 +1,29 @@
|
|||
package storage_test
|
||||
|
||||
// TODO: Fix these tests
|
||||
/*
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var (
|
||||
ctx = context.Background()
|
||||
emptyStateKey = ""
|
||||
testOrigin = gomatrixserverlib.ServerName("hollow.knight")
|
||||
testRoomID = fmt.Sprintf("!hallownest:%s", testOrigin)
|
||||
testUserIDA = fmt.Sprintf("@hornet:%s", testOrigin)
|
||||
testUserIDB = fmt.Sprintf("@paleking:%s", testOrigin)
|
||||
testUserDeviceA = userapi.Device{
|
||||
UserID: testUserIDA,
|
||||
ID: "device_id_A",
|
||||
DisplayName: "Device A",
|
||||
}
|
||||
testRoomVersion = gomatrixserverlib.RoomVersionV4
|
||||
testKeyID = gomatrixserverlib.KeyID("ed25519:storage_test")
|
||||
testPrivateKey = ed25519.NewKeyFromSeed([]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
|
||||
})
|
||||
)
|
||||
var ctx = context.Background()
|
||||
|
||||
func MustCreateEvent(t *testing.T, roomID string, prevs []*gomatrixserverlib.HeaderedEvent, b *gomatrixserverlib.EventBuilder) *gomatrixserverlib.HeaderedEvent {
|
||||
b.RoomID = roomID
|
||||
if prevs != nil {
|
||||
prevIDs := make([]string, len(prevs))
|
||||
for i := range prevs {
|
||||
prevIDs[i] = prevs[i].EventID()
|
||||
}
|
||||
b.PrevEvents = prevIDs
|
||||
}
|
||||
e, err := b.Build(time.Now(), testOrigin, testKeyID, testPrivateKey, testRoomVersion)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build event: %s", err)
|
||||
}
|
||||
return e.Headered(testRoomVersion)
|
||||
}
|
||||
|
||||
func MustCreateDatabase(t *testing.T) storage.Database {
|
||||
dbname := fmt.Sprintf("test_%s.db", t.Name())
|
||||
if _, err := os.Stat(dbname); err == nil {
|
||||
if err = os.Remove(dbname); err != nil {
|
||||
t.Fatalf("tried to delete stale test database but failed: %s", err)
|
||||
}
|
||||
}
|
||||
db, err := sqlite3.NewDatabase(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(fmt.Sprintf("file:%s", dbname)),
|
||||
func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := storage.NewSyncServerDatasource(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewSyncServerDatasource returned %s", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// Create a list of events which include a create event, join event and some messages.
|
||||
func SimpleRoom(t *testing.T, roomID, userA, userB string) (msgs []*gomatrixserverlib.HeaderedEvent, state []*gomatrixserverlib.HeaderedEvent) {
|
||||
var events []*gomatrixserverlib.HeaderedEvent
|
||||
events = append(events, MustCreateEvent(t, roomID, nil, &gomatrixserverlib.EventBuilder{
|
||||
Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, userA)),
|
||||
Type: "m.room.create",
|
||||
StateKey: &emptyStateKey,
|
||||
Sender: userA,
|
||||
Depth: int64(len(events) + 1),
|
||||
}))
|
||||
state = append(state, events[len(events)-1])
|
||||
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
|
||||
Content: []byte(`{"membership":"join"}`),
|
||||
Type: "m.room.member",
|
||||
StateKey: &userA,
|
||||
Sender: userA,
|
||||
Depth: int64(len(events) + 1),
|
||||
}))
|
||||
state = append(state, events[len(events)-1])
|
||||
for i := 0; i < 10; i++ {
|
||||
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
|
||||
Content: []byte(fmt.Sprintf(`{"body":"Message A %d"}`, i+1)),
|
||||
Type: "m.room.message",
|
||||
Sender: userA,
|
||||
Depth: int64(len(events) + 1),
|
||||
}))
|
||||
}
|
||||
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
|
||||
Content: []byte(`{"membership":"join"}`),
|
||||
Type: "m.room.member",
|
||||
StateKey: &userB,
|
||||
Sender: userB,
|
||||
Depth: int64(len(events) + 1),
|
||||
}))
|
||||
state = append(state, events[len(events)-1])
|
||||
for i := 0; i < 10; i++ {
|
||||
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
|
||||
Content: []byte(fmt.Sprintf(`{"body":"Message B %d"}`, i+1)),
|
||||
Type: "m.room.message",
|
||||
Sender: userB,
|
||||
Depth: int64(len(events) + 1),
|
||||
}))
|
||||
}
|
||||
|
||||
return events, state
|
||||
return db, close
|
||||
}
|
||||
|
||||
func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) {
|
||||
|
|
@ -131,206 +39,158 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
|
|||
if err != nil {
|
||||
t.Fatalf("WriteEvent failed: %s", err)
|
||||
}
|
||||
fmt.Println("Event ID", ev.EventID(), " spos=", pos, "depth=", ev.Depth())
|
||||
t.Logf("Event ID %s spos=%v depth=%v", ev.EventID(), pos, ev.Depth())
|
||||
positions = append(positions, pos)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestWriteEvents(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := MustCreateDatabase(t)
|
||||
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
||||
MustWriteEvents(t, db, events)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
alice := test.NewUser()
|
||||
r := test.NewRoom(t, alice)
|
||||
db, close := MustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
MustWriteEvents(t, db, r.Events())
|
||||
})
|
||||
}
|
||||
|
||||
// These tests assert basic functionality of the IncrementalSync and CompleteSync functions.
|
||||
func TestSyncResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := MustCreateDatabase(t)
|
||||
events, state := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
||||
positions := MustWriteEvents(t, db, events)
|
||||
latest, err := db.SyncPosition(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get SyncPosition: %s", err)
|
||||
}
|
||||
// These tests assert basic functionality of RecentEvents for PDUs
|
||||
func TestRecentEventsPDU(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := MustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser()
|
||||
// dummy room to make sure SQL queries are filtering on room ID
|
||||
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
|
||||
|
||||
testCases := []struct {
|
||||
Name string
|
||||
DoSync func() (*types.Response, error)
|
||||
WantTimeline []*gomatrixserverlib.HeaderedEvent
|
||||
WantState []*gomatrixserverlib.HeaderedEvent
|
||||
}{
|
||||
// The purpose of this test is to make sure that incremental syncs are including up to the latest events.
|
||||
// It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event.
|
||||
// It makes sure the response includes the final event.
|
||||
{
|
||||
Name: "IncrementalSync penultimate",
|
||||
DoSync: func() (*types.Response, error) {
|
||||
from := types.StreamingToken{ // pretend we are at the penultimate event
|
||||
PDUPosition: positions[len(positions)-2],
|
||||
// actual test room
|
||||
r := test.NewRoom(t, alice)
|
||||
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"})
|
||||
events := r.Events()
|
||||
positions := MustWriteEvents(t, db, events)
|
||||
|
||||
// dummy room to make sure SQL queries are filtering on room ID
|
||||
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
|
||||
|
||||
latest, err := db.MaxStreamPositionForPDUs(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
Name string
|
||||
From types.StreamPosition
|
||||
To types.StreamPosition
|
||||
Limit int
|
||||
ReverseOrder bool
|
||||
WantEvents []*gomatrixserverlib.HeaderedEvent
|
||||
WantLimited bool
|
||||
}{
|
||||
// The purpose of this test is to make sure that incremental syncs are including up to the latest events.
|
||||
// It's a basic sanity test that sync works. It creates a streaming position that is on the penultimate event.
|
||||
// It makes sure the response includes the final event.
|
||||
{
|
||||
Name: "penultimate",
|
||||
From: positions[len(positions)-2], // pretend we are at the penultimate event
|
||||
To: latest,
|
||||
Limit: 100,
|
||||
WantEvents: events[len(events)-1:],
|
||||
WantLimited: false,
|
||||
},
|
||||
// The purpose of this test is to check that limits can be applied and work.
|
||||
// This is critical for big rooms hence the test here.
|
||||
{
|
||||
Name: "limited",
|
||||
From: 0,
|
||||
To: latest,
|
||||
Limit: 1,
|
||||
WantEvents: events[len(events)-1:],
|
||||
WantLimited: true,
|
||||
},
|
||||
// The purpose of this test is to check that we can return every event with a high
|
||||
// enough limit
|
||||
{
|
||||
Name: "large limited",
|
||||
From: 0,
|
||||
To: latest,
|
||||
Limit: 100,
|
||||
WantEvents: events,
|
||||
WantLimited: false,
|
||||
},
|
||||
// The purpose of this test is to check that we can return events in reverse order
|
||||
{
|
||||
Name: "reverse",
|
||||
From: positions[len(positions)-3], // 2 events back
|
||||
To: latest,
|
||||
Limit: 100,
|
||||
ReverseOrder: true,
|
||||
WantEvents: test.Reversed(events[len(events)-2:]),
|
||||
WantLimited: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range testCases {
|
||||
tc := testCases[i]
|
||||
t.Run(tc.Name, func(st *testing.T) {
|
||||
var filter gomatrixserverlib.RoomEventFilter
|
||||
filter.Limit = tc.Limit
|
||||
gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
|
||||
From: tc.From,
|
||||
To: tc.To,
|
||||
}, &filter, !tc.ReverseOrder, true)
|
||||
if err != nil {
|
||||
st.Fatalf("failed to do sync: %s", err)
|
||||
}
|
||||
res := types.NewResponse()
|
||||
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
|
||||
},
|
||||
WantTimeline: events[len(events)-1:],
|
||||
},
|
||||
// The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the
|
||||
// number of returned events. This is critical for big rooms hence the test here.
|
||||
{
|
||||
Name: "IncrementalSync limited",
|
||||
DoSync: func() (*types.Response, error) {
|
||||
from := types.StreamingToken{ // pretend we are 10 events behind
|
||||
PDUPosition: positions[len(positions)-11],
|
||||
if limited != tc.WantLimited {
|
||||
st.Errorf("got limited=%v want %v", limited, tc.WantLimited)
|
||||
}
|
||||
res := types.NewResponse()
|
||||
// limit is set to 5
|
||||
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
|
||||
},
|
||||
// want the last 5 events, NOT the last 10.
|
||||
WantTimeline: events[len(events)-5:],
|
||||
},
|
||||
// The purpose of this test is to check that CompleteSync returns all the current state as well as
|
||||
// honouring the `numRecentEventsPerRoom` value
|
||||
{
|
||||
Name: "CompleteSync limited",
|
||||
DoSync: func() (*types.Response, error) {
|
||||
res := types.NewResponse()
|
||||
// limit set to 5
|
||||
return db.CompleteSync(ctx, res, testUserDeviceA, 5)
|
||||
},
|
||||
// want the last 5 events
|
||||
WantTimeline: events[len(events)-5:],
|
||||
// want all state for the room
|
||||
WantState: state,
|
||||
},
|
||||
// The purpose of this test is to check that CompleteSync can return everything with a high enough
|
||||
// `numRecentEventsPerRoom`.
|
||||
{
|
||||
Name: "CompleteSync",
|
||||
DoSync: func() (*types.Response, error) {
|
||||
res := types.NewResponse()
|
||||
return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
|
||||
},
|
||||
WantTimeline: events,
|
||||
// We want no state at all as that field in /sync is the delta between the token (beginning of time)
|
||||
// and the START of the timeline.
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Name, func(st *testing.T) {
|
||||
res, err := tc.DoSync()
|
||||
if err != nil {
|
||||
st.Fatalf("failed to do sync: %s", err)
|
||||
}
|
||||
next := types.StreamingToken{
|
||||
PDUPosition: latest.PDUPosition,
|
||||
TypingPosition: latest.TypingPosition,
|
||||
ReceiptPosition: latest.ReceiptPosition,
|
||||
SendToDevicePosition: latest.SendToDevicePosition,
|
||||
}
|
||||
if res.NextBatch.String() != next.String() {
|
||||
st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String())
|
||||
}
|
||||
roomRes, ok := res.Rooms.Join[testRoomID]
|
||||
if !ok {
|
||||
st.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res)
|
||||
}
|
||||
assertEventsEqual(st, "state for "+testRoomID, false, roomRes.State.Events, tc.WantState)
|
||||
assertEventsEqual(st, "timeline for "+testRoomID, false, roomRes.Timeline.Events, tc.WantTimeline)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := MustCreateDatabase(t)
|
||||
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
||||
positions := MustWriteEvents(t, db, events)
|
||||
latest, err := db.SyncPosition(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get SyncPosition: %s", err)
|
||||
}
|
||||
from := types.StreamingToken{
|
||||
PDUPosition: positions[len(positions)-2],
|
||||
}
|
||||
|
||||
res := types.NewResponse()
|
||||
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to IncrementalSync with latest token")
|
||||
}
|
||||
roomRes, ok := res.Rooms.Join[testRoomID]
|
||||
if !ok {
|
||||
t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res)
|
||||
}
|
||||
// returns the last event "Message 10"
|
||||
assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:]))
|
||||
|
||||
prev := roomRes.Timeline.PrevBatch.String()
|
||||
if prev == "" {
|
||||
t.Fatalf("IncrementalSync expected prev_batch token")
|
||||
}
|
||||
prevBatchToken, err := types.NewTopologyTokenFromString(prev)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to NewTopologyTokenFromString : %s", err)
|
||||
}
|
||||
// backpaginate 5 messages starting at the latest position.
|
||||
// head towards the beginning of time
|
||||
to := types.TopologyToken{}
|
||||
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true)
|
||||
if err != nil {
|
||||
t.Fatalf("GetEventsInRange returned an error: %s", err)
|
||||
}
|
||||
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
|
||||
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1]))
|
||||
}
|
||||
|
||||
// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token.
|
||||
func TestGetEventsInRangeWithStreamToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := MustCreateDatabase(t)
|
||||
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
||||
MustWriteEvents(t, db, events)
|
||||
latest, err := db.SyncPosition(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get SyncPosition: %s", err)
|
||||
}
|
||||
// head towards the beginning of time
|
||||
to := types.StreamingToken{}
|
||||
|
||||
// backpaginate 5 messages starting at the latest position.
|
||||
paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
|
||||
if err != nil {
|
||||
t.Fatalf("GetEventsInRange returned an error: %s", err)
|
||||
}
|
||||
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
|
||||
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
|
||||
if len(gotEvents) != len(tc.WantEvents) {
|
||||
st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents))
|
||||
}
|
||||
for j := range gotEvents {
|
||||
if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) {
|
||||
st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON()))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token
|
||||
func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := MustCreateDatabase(t)
|
||||
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
||||
MustWriteEvents(t, db, events)
|
||||
from, err := db.MaxTopologicalPosition(ctx, testRoomID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
|
||||
}
|
||||
// head towards the beginning of time
|
||||
to := types.TopologyToken{}
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := MustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser()
|
||||
r := test.NewRoom(t, alice)
|
||||
for i := 0; i < 10; i++ {
|
||||
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})
|
||||
}
|
||||
events := r.Events()
|
||||
_ = MustWriteEvents(t, db, events)
|
||||
|
||||
// backpaginate 5 messages starting at the latest position.
|
||||
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true)
|
||||
if err != nil {
|
||||
t.Fatalf("GetEventsInRange returned an error: %s", err)
|
||||
}
|
||||
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
|
||||
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
|
||||
from, err := db.MaxTopologicalPosition(ctx, r.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
|
||||
}
|
||||
t.Logf("max topo pos = %+v", from)
|
||||
// head towards the beginning of time
|
||||
to := types.TopologyToken{}
|
||||
|
||||
// backpaginate 5 messages starting at the latest position.
|
||||
filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
|
||||
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
|
||||
if err != nil {
|
||||
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
|
||||
}
|
||||
gots := db.StreamEventsToEvents(nil, paginatedEvents)
|
||||
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth.
|
||||
// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent
|
||||
// will appear FIRST when going backwards. This test creates a DAG like:
|
||||
|
|
@ -740,12 +600,4 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ
|
|||
tok.Decrement()
|
||||
return &tok
|
||||
}
|
||||
|
||||
func reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
|
||||
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
|
||||
for i := 0; i < len(in); i++ {
|
||||
out[i] = in[len(in)-i-1]
|
||||
}
|
||||
return out
|
||||
}
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ type Events interface {
|
|||
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
|
||||
// SelectEarlyEvents returns the earliest events in the given room.
|
||||
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
|
||||
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
|
||||
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error)
|
||||
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
|
||||
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
|
||||
DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
|
||||
|
|
@ -84,8 +84,6 @@ type Topology interface {
|
|||
SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error)
|
||||
// SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position.
|
||||
SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error)
|
||||
// DeleteTopologyForRoom removes all topological information for a room. This should only be done when removing the room entirely.
|
||||
DeleteTopologyForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
|
||||
// SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room.
|
||||
SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error)
|
||||
}
|
||||
|
|
@ -132,8 +130,6 @@ type BackwardsExtremities interface {
|
|||
SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error)
|
||||
// DeleteBackwardExtremity removes a backwards extremity for a room, if one existed.
|
||||
DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error)
|
||||
// DeleteBackwardExtremitiesFoorRoomID removes all backward extremities for a room. This should only be done when removing the room entirely.
|
||||
DeleteBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
|
||||
}
|
||||
|
||||
// SendToDevice tracks send-to-device messages which are sent to individual
|
||||
|
|
@ -173,7 +169,6 @@ type Receipts interface {
|
|||
|
||||
type Memberships interface {
|
||||
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
|
||||
SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error)
|
||||
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
|
||||
}
|
||||
|
||||
|
|
|
|||
105
syncapi/storage/tables/output_room_events_test.go
Normal file
105
syncapi/storage/tables/output_room_events_test.go
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open db: %s", err)
|
||||
}
|
||||
|
||||
var tab tables.Events
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
tab, err = postgres.NewPostgresEventsTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
var stream sqlite3.StreamIDStatements
|
||||
if err = stream.Prepare(db); err != nil {
|
||||
t.Fatalf("failed to prepare stream stmts: %s", err)
|
||||
}
|
||||
tab, err = sqlite3.NewSqliteEventsTable(db, &stream)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make new table: %s", err)
|
||||
}
|
||||
return tab, db, close
|
||||
}
|
||||
|
||||
func TestOutputRoomEventsTable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
alice := test.NewUser()
|
||||
room := test.NewRoom(t, alice)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, db, close := newOutputRoomEventsTable(t, dbType)
|
||||
defer close()
|
||||
events := room.Events()
|
||||
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
|
||||
for _, ev := range events {
|
||||
_, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to InsertEvent: %s", err)
|
||||
}
|
||||
}
|
||||
// order = 2,0,3,1
|
||||
wantEventIDs := []string{
|
||||
events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(),
|
||||
}
|
||||
gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEvents: %s", err)
|
||||
}
|
||||
gotEventIDs := make([]string, len(gotEvents))
|
||||
for i := range gotEvents {
|
||||
gotEventIDs[i] = gotEvents[i].EventID()
|
||||
}
|
||||
if !reflect.DeepEqual(gotEventIDs, wantEventIDs) {
|
||||
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs)
|
||||
}
|
||||
|
||||
// Test that contains_url is correctly populated
|
||||
urlEv := room.CreateEvent(t, alice, "m.text", map[string]interface{}{
|
||||
"body": "test.txt",
|
||||
"url": "mxc://test.txt",
|
||||
})
|
||||
if _, err = tab.InsertEvent(ctx, txn, urlEv, nil, nil, nil, false); err != nil {
|
||||
return fmt.Errorf("failed to InsertEvent: %s", err)
|
||||
}
|
||||
wantEventID := []string{urlEv.EventID()}
|
||||
t := true
|
||||
gotEvents, err = tab.SelectEvents(ctx, txn, wantEventID, &gomatrixserverlib.RoomEventFilter{Limit: 1, ContainsURL: &t}, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEvents: %s", err)
|
||||
}
|
||||
gotEventIDs = make([]string, len(gotEvents))
|
||||
for i := range gotEvents {
|
||||
gotEventIDs[i] = gotEvents[i].EventID()
|
||||
}
|
||||
if !reflect.DeepEqual(gotEventIDs, wantEventID) {
|
||||
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventID)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
91
syncapi/storage/tables/topology_test.go
Normal file
91
syncapi/storage/tables/topology_test.go
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
)
|
||||
|
||||
func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open db: %s", err)
|
||||
}
|
||||
|
||||
var tab tables.Topology
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
tab, err = postgres.NewPostgresTopologyTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
tab, err = sqlite3.NewSqliteTopologyTable(db)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make new table: %s", err)
|
||||
}
|
||||
return tab, db, close
|
||||
}
|
||||
|
||||
func TestTopologyTable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
alice := test.NewUser()
|
||||
room := test.NewRoom(t, alice)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, db, close := newTopologyTable(t, dbType)
|
||||
defer close()
|
||||
events := room.Events()
|
||||
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
|
||||
var highestPos types.StreamPosition
|
||||
for i, ev := range events {
|
||||
topoPos, err := tab.InsertEventInTopology(ctx, txn, ev, types.StreamPosition(i))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to InsertEventInTopology: %s", err)
|
||||
}
|
||||
// topo pos = depth, depth starts at 1, hence 1+i
|
||||
if topoPos != types.StreamPosition(1+i) {
|
||||
return fmt.Errorf("got topo pos %d want %d", topoPos, 1+i)
|
||||
}
|
||||
highestPos = topoPos + 1
|
||||
}
|
||||
// check ordering works without limit
|
||||
eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||
}
|
||||
test.AssertEventIDsEqual(t, eventIDs, events[:])
|
||||
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||
}
|
||||
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:]))
|
||||
// check ordering works with limit
|
||||
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||
}
|
||||
test.AssertEventIDsEqual(t, eventIDs, events[:3])
|
||||
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||
}
|
||||
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:]))
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -26,7 +27,8 @@ type PDUStreamProvider struct {
|
|||
|
||||
tasks chan func()
|
||||
workers atomic.Int32
|
||||
userAPI userapi.UserInternalAPI
|
||||
// userID+deviceID -> lazy loading cache
|
||||
lazyLoadCache *caching.LazyLoadCache
|
||||
}
|
||||
|
||||
func (p *PDUStreamProvider) worker() {
|
||||
|
|
@ -188,7 +190,7 @@ func (p *PDUStreamProvider) IncrementalSync(
|
|||
newPos = from
|
||||
for _, delta := range stateDeltas {
|
||||
var pos types.StreamPosition
|
||||
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, req.Response); err != nil {
|
||||
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, &stateFilter, req.Response); err != nil {
|
||||
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
|
||||
return to
|
||||
}
|
||||
|
|
@ -209,6 +211,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
|||
r types.Range,
|
||||
delta types.StateDelta,
|
||||
eventFilter *gomatrixserverlib.RoomEventFilter,
|
||||
stateFilter *gomatrixserverlib.StateFilter,
|
||||
res *types.Response,
|
||||
) (types.StreamPosition, error) {
|
||||
if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave {
|
||||
|
|
@ -247,7 +250,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
|||
// room that were returned.
|
||||
latestPosition := r.To
|
||||
updateLatestPosition := func(mostRecentEventID string) {
|
||||
if _, pos, err := p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil {
|
||||
var pos types.StreamPosition
|
||||
if _, pos, err = p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil {
|
||||
switch {
|
||||
case r.Backwards && pos > latestPosition:
|
||||
fallthrough
|
||||
|
|
@ -263,6 +267,19 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
|||
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
|
||||
}
|
||||
|
||||
if stateFilter.LazyLoadMembers {
|
||||
if err != nil {
|
||||
return r.From, err
|
||||
}
|
||||
delta.StateEvents, err = p.lazyLoadMembers(
|
||||
ctx, delta.RoomID, true, limited, stateFilter.IncludeRedundantMembers,
|
||||
device, recentEvents, delta.StateEvents,
|
||||
)
|
||||
if err != nil {
|
||||
return r.From, err
|
||||
}
|
||||
}
|
||||
|
||||
hasMembershipChange := false
|
||||
for _, recentEvent := range recentStreamEvents {
|
||||
if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil {
|
||||
|
|
@ -402,6 +419,20 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
|||
// "Can sync a room with a message with a transaction id" - which does a complete sync to check.
|
||||
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
|
||||
stateEvents = removeDuplicates(stateEvents, recentEvents)
|
||||
|
||||
if stateFilter.LazyLoadMembers {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stateEvents, err = p.lazyLoadMembers(ctx, roomID,
|
||||
false, limited, stateFilter.IncludeRedundantMembers,
|
||||
device, recentEvents, stateEvents,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
jr = types.NewJoinResponse()
|
||||
jr.Summary.JoinedMemberCount = &joinedCount
|
||||
jr.Summary.InvitedMemberCount = &invitedCount
|
||||
|
|
@ -412,6 +443,69 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
|||
return jr, nil
|
||||
}
|
||||
|
||||
func (p *PDUStreamProvider) lazyLoadMembers(
|
||||
ctx context.Context, roomID string,
|
||||
incremental, limited, includeRedundant bool,
|
||||
device *userapi.Device,
|
||||
timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
|
||||
) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
if len(timelineEvents) == 0 {
|
||||
return stateEvents, nil
|
||||
}
|
||||
// Work out which memberships to include
|
||||
timelineUsers := make(map[string]struct{})
|
||||
if !incremental {
|
||||
timelineUsers[device.UserID] = struct{}{}
|
||||
}
|
||||
// Add all users the client doesn't know about yet to a list
|
||||
for _, event := range timelineEvents {
|
||||
// Membership is not yet cached, add it to the list
|
||||
if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.Sender()); !ok {
|
||||
timelineUsers[event.Sender()] = struct{}{}
|
||||
}
|
||||
}
|
||||
// Preallocate with the same amount, even if it will end up with fewer values
|
||||
newStateEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(stateEvents))
|
||||
// Remove existing membership events we don't care about, e.g. users not in the timeline.events
|
||||
for _, event := range stateEvents {
|
||||
if event.Type() == gomatrixserverlib.MRoomMember && event.StateKey() != nil {
|
||||
// If this is a gapped incremental sync, we still want this membership
|
||||
isGappedIncremental := limited && incremental
|
||||
// We want this users membership event, keep it in the list
|
||||
_, ok := timelineUsers[event.Sender()]
|
||||
wantMembership := ok || isGappedIncremental
|
||||
if wantMembership {
|
||||
newStateEvents = append(newStateEvents, event)
|
||||
if !includeRedundant {
|
||||
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, event.Sender(), event.EventID())
|
||||
}
|
||||
delete(timelineUsers, event.Sender())
|
||||
}
|
||||
} else {
|
||||
newStateEvents = append(newStateEvents, event)
|
||||
}
|
||||
}
|
||||
wantUsers := make([]string, 0, len(timelineUsers))
|
||||
for userID := range timelineUsers {
|
||||
wantUsers = append(wantUsers, userID)
|
||||
}
|
||||
// Query missing membership events
|
||||
memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &gomatrixserverlib.StateFilter{
|
||||
Limit: 100,
|
||||
Senders: &wantUsers,
|
||||
Types: &[]string{gomatrixserverlib.MRoomMember},
|
||||
})
|
||||
if err != nil {
|
||||
return stateEvents, err
|
||||
}
|
||||
// cache the membership events
|
||||
for _, membership := range memberships {
|
||||
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, membership.Sender(), membership.EventID())
|
||||
}
|
||||
stateEvents = append(newStateEvents, memberships...)
|
||||
return stateEvents, nil
|
||||
}
|
||||
|
||||
// addIgnoredUsersToFilter adds ignored users to the eventfilter and
|
||||
// the syncreq itself for further use in streams.
|
||||
func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error {
|
||||
|
|
@ -423,8 +517,12 @@ func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *ty
|
|||
return err
|
||||
}
|
||||
req.IgnoredUsers = *ignores
|
||||
userList := make([]string, 0, len(ignores.List))
|
||||
for userID := range ignores.List {
|
||||
eventFilter.NotSenders = append(eventFilter.NotSenders, userID)
|
||||
userList = append(userList, userID)
|
||||
}
|
||||
if len(userList) > 0 {
|
||||
eventFilter.NotSenders = &userList
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,12 +27,12 @@ type Streams struct {
|
|||
func NewSyncStreamProviders(
|
||||
d storage.Database, userAPI userapi.UserInternalAPI,
|
||||
rsAPI rsapi.RoomserverInternalAPI, keyAPI keyapi.KeyInternalAPI,
|
||||
eduCache *caching.EDUCache, notifier *notifier.Notifier,
|
||||
eduCache *caching.EDUCache, lazyLoadCache *caching.LazyLoadCache, notifier *notifier.Notifier,
|
||||
) *Streams {
|
||||
streams := &Streams{
|
||||
PDUStreamProvider: &PDUStreamProvider{
|
||||
StreamProvider: StreamProvider{DB: d},
|
||||
userAPI: userAPI,
|
||||
lazyLoadCache: lazyLoadCache,
|
||||
},
|
||||
TypingStreamProvider: &TypingStreamProvider{
|
||||
StreamProvider: StreamProvider{DB: d},
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
package sync
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
|
@ -60,10 +61,10 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
|
|||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||
}
|
||||
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil {
|
||||
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil && err != sql.ErrNoRows {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed")
|
||||
return nil, fmt.Errorf("syncDB.GetFilter: %w", err)
|
||||
} else {
|
||||
} else if f != nil {
|
||||
filter = *f
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -67,6 +67,9 @@ func NewRequestPool(
|
|||
streams *streams.Streams, notifier *notifier.Notifier,
|
||||
producer PresencePublisher,
|
||||
) *RequestPool {
|
||||
prometheus.MustRegister(
|
||||
activeSyncRequests, waitingSyncRequests,
|
||||
)
|
||||
rp := &RequestPool{
|
||||
db: db,
|
||||
cfg: cfg,
|
||||
|
|
@ -183,12 +186,6 @@ func (rp *RequestPool) updateLastSeen(req *http.Request, device *userapi.Device)
|
|||
rp.lastseen.Store(device.UserID+device.ID, time.Now())
|
||||
}
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(
|
||||
activeSyncRequests, waitingSyncRequests,
|
||||
)
|
||||
}
|
||||
|
||||
var activeSyncRequests = prometheus.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: "dendrite",
|
||||
|
|
|
|||
|
|
@ -57,8 +57,12 @@ func AddPublicRoutes(
|
|||
}
|
||||
|
||||
eduCache := caching.NewTypingCache()
|
||||
lazyLoadCache, err := caching.NewLazyLoadCache()
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to create lazy loading cache")
|
||||
}
|
||||
notifier := notifier.NewNotifier()
|
||||
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, notifier)
|
||||
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, lazyLoadCache, notifier)
|
||||
notifier.SetCurrentPosition(streams.Latest(context.Background()))
|
||||
if err = notifier.Load(context.Background(), syncDB); err != nil {
|
||||
logrus.WithError(err).Panicf("failed to load notifier ")
|
||||
|
|
|
|||
|
|
@ -696,4 +696,17 @@ Room state after a rejected message event is the same as before
|
|||
Room state after a rejected state event is the same as before
|
||||
Ignore user in existing room
|
||||
Ignore invite in full sync
|
||||
Ignore invite in incremental sync
|
||||
Ignore invite in incremental sync
|
||||
A filtered timeline reaches its limit
|
||||
A change to displayname should not result in a full state sync
|
||||
Can fetch images in room
|
||||
The only membership state included in an initial sync is for all the senders in the timeline
|
||||
The only membership state included in an incremental sync is for senders in the timeline
|
||||
Old members are included in gappy incr LL sync if they start speaking
|
||||
We do send redundant membership state across incremental syncs if asked
|
||||
Rejecting invite over federation doesn't break incremental /sync
|
||||
Gapped incremental syncs include all state changes
|
||||
Old leaves are present in gapped incremental syncs
|
||||
Leaves are present in non-gapped incremental syncs
|
||||
Members from the gap are included in gappy incr LL sync
|
||||
Presence can be set from sync
|
||||
170
test/db.go
Normal file
170
test/db.go
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"testing"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type DBType int
|
||||
|
||||
var DBTypeSQLite DBType = 1
|
||||
var DBTypePostgres DBType = 2
|
||||
|
||||
var Quiet = false
|
||||
|
||||
func createLocalDB(dbName string) {
|
||||
if !Quiet {
|
||||
fmt.Println("Note: tests require a postgres install accessible to the current user")
|
||||
}
|
||||
createDB := exec.Command("createdb", dbName)
|
||||
if !Quiet {
|
||||
createDB.Stdout = os.Stdout
|
||||
createDB.Stderr = os.Stderr
|
||||
}
|
||||
err := createDB.Run()
|
||||
if err != nil && !Quiet {
|
||||
fmt.Println("createLocalDB returned error:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createRemoteDB(t *testing.T, dbName, user, connStr string) {
|
||||
db, err := sql.Open("postgres", connStr+" dbname=postgres")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open postgres conn with connstr=%s : %s", connStr, err)
|
||||
}
|
||||
_, err = db.Exec(fmt.Sprintf(`CREATE DATABASE %s;`, dbName))
|
||||
if err != nil {
|
||||
pqErr, ok := err.(*pq.Error)
|
||||
if !ok {
|
||||
t.Fatalf("failed to CREATE DATABASE: %s", err)
|
||||
}
|
||||
// we ignore duplicate database error as we expect this
|
||||
if pqErr.Code != "42P04" {
|
||||
t.Fatalf("failed to CREATE DATABASE with code=%s msg=%s", pqErr.Code, pqErr.Message)
|
||||
}
|
||||
}
|
||||
_, err = db.Exec(fmt.Sprintf(`GRANT ALL PRIVILEGES ON DATABASE %s TO %s`, dbName, user))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to GRANT: %s", err)
|
||||
}
|
||||
_ = db.Close()
|
||||
}
|
||||
|
||||
func currentUser() string {
|
||||
user, err := user.Current()
|
||||
if err != nil {
|
||||
if !Quiet {
|
||||
fmt.Println("cannot get current user: ", err)
|
||||
}
|
||||
os.Exit(2)
|
||||
}
|
||||
return user.Username
|
||||
}
|
||||
|
||||
// Prepare a sqlite or postgres connection string for testing.
|
||||
// Returns the connection string to use and a close function which must be called when the test finishes.
|
||||
// Calling this function twice will return the same database, which will have data from previous tests
|
||||
// unless close() is called.
|
||||
// TODO: namespace for concurrent package tests
|
||||
func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) {
|
||||
if dbType == DBTypeSQLite {
|
||||
// this will be made in the current working directory which namespaces concurrent package runs correctly
|
||||
dbname := "dendrite_test.db"
|
||||
return fmt.Sprintf("file:%s", dbname), func() {
|
||||
err := os.Remove(dbname)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to cleanup sqlite db '%s': %s", dbname, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Required vars: user and db
|
||||
// We'll try to infer from the local env if they are missing
|
||||
user := os.Getenv("POSTGRES_USER")
|
||||
if user == "" {
|
||||
user = currentUser()
|
||||
}
|
||||
connStr = fmt.Sprintf(
|
||||
"user=%s sslmode=disable",
|
||||
user,
|
||||
)
|
||||
// optional vars, used in CI
|
||||
password := os.Getenv("POSTGRES_PASSWORD")
|
||||
if password != "" {
|
||||
connStr += fmt.Sprintf(" password=%s", password)
|
||||
}
|
||||
host := os.Getenv("POSTGRES_HOST")
|
||||
if host != "" {
|
||||
connStr += fmt.Sprintf(" host=%s", host)
|
||||
}
|
||||
|
||||
// superuser database
|
||||
postgresDB := os.Getenv("POSTGRES_DB")
|
||||
// we cannot use 'dendrite_test' here else 2x concurrently running packages will try to use the same db.
|
||||
// instead, hash the current working directory, snaffle the first 16 bytes and append that to dendrite_test
|
||||
// and use that as the unique db name. We do this because packages are per-directory hence by hashing the
|
||||
// working (test) directory we ensure we get a consistent hash and don't hash against concurrent packages.
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("cannot get working directory: %s", err)
|
||||
}
|
||||
hash := sha256.Sum256([]byte(wd))
|
||||
dbName := fmt.Sprintf("dendrite_test_%s", hex.EncodeToString(hash[:16]))
|
||||
if postgresDB == "" { // local server, use createdb
|
||||
createLocalDB(dbName)
|
||||
} else { // remote server, shell into the postgres user and CREATE DATABASE
|
||||
createRemoteDB(t, dbName, user, connStr)
|
||||
}
|
||||
connStr += fmt.Sprintf(" dbname=%s", dbName)
|
||||
|
||||
return connStr, func() {
|
||||
// Drop all tables on the database to get a fresh instance
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to postgres db '%s': %s", connStr, err)
|
||||
}
|
||||
_, err = db.Exec(`DROP SCHEMA public CASCADE;
|
||||
CREATE SCHEMA public;`)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to cleanup postgres db '%s': %s", connStr, err)
|
||||
}
|
||||
_ = db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Creates subtests with each known DBType
|
||||
func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) {
|
||||
dbs := map[string]DBType{
|
||||
"postgres": DBTypePostgres,
|
||||
"sqlite": DBTypeSQLite,
|
||||
}
|
||||
for dbName, dbType := range dbs {
|
||||
dbt := dbType
|
||||
t.Run(dbName, func(tt *testing.T) {
|
||||
tt.Parallel()
|
||||
testFn(tt, dbt)
|
||||
})
|
||||
}
|
||||
}
|
||||
90
test/event.go
Normal file
90
test/event.go
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type eventMods struct {
|
||||
originServerTS time.Time
|
||||
origin gomatrixserverlib.ServerName
|
||||
stateKey *string
|
||||
unsigned interface{}
|
||||
keyID gomatrixserverlib.KeyID
|
||||
privKey ed25519.PrivateKey
|
||||
}
|
||||
|
||||
type eventModifier func(e *eventMods)
|
||||
|
||||
func WithTimestamp(ts time.Time) eventModifier {
|
||||
return func(e *eventMods) {
|
||||
e.originServerTS = ts
|
||||
}
|
||||
}
|
||||
|
||||
func WithStateKey(skey string) eventModifier {
|
||||
return func(e *eventMods) {
|
||||
e.stateKey = &skey
|
||||
}
|
||||
}
|
||||
|
||||
func WithUnsigned(unsigned interface{}) eventModifier {
|
||||
return func(e *eventMods) {
|
||||
e.unsigned = unsigned
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse a list of events
|
||||
func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
|
||||
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
|
||||
for i := 0; i < len(in); i++ {
|
||||
out[i] = in[len(in)-i-1]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*gomatrixserverlib.HeaderedEvent) {
|
||||
t.Helper()
|
||||
if len(gotEventIDs) != len(wants) {
|
||||
t.Fatalf("length mismatch: got %d events, want %d", len(gotEventIDs), len(wants))
|
||||
}
|
||||
for i := range wants {
|
||||
w := wants[i].EventID()
|
||||
g := gotEventIDs[i]
|
||||
if w != g {
|
||||
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AssertEventsEqual(t *testing.T, gots, wants []*gomatrixserverlib.HeaderedEvent) {
|
||||
t.Helper()
|
||||
if len(gots) != len(wants) {
|
||||
t.Fatalf("length mismatch: got %d events, want %d", len(gots), len(wants))
|
||||
}
|
||||
for i := range wants {
|
||||
w := wants[i].JSON()
|
||||
g := gots[i].JSON()
|
||||
if !bytes.Equal(w, g) {
|
||||
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
|
||||
}
|
||||
}
|
||||
}
|
||||
223
test/room.go
Normal file
223
test/room.go
Normal file
|
|
@ -0,0 +1,223 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type Preset int
|
||||
|
||||
var (
|
||||
PresetNone Preset = 0
|
||||
PresetPrivateChat Preset = 1
|
||||
PresetPublicChat Preset = 2
|
||||
PresetTrustedPrivateChat Preset = 3
|
||||
|
||||
roomIDCounter = int64(0)
|
||||
|
||||
testKeyID = gomatrixserverlib.KeyID("ed25519:test")
|
||||
testPrivateKey = ed25519.NewKeyFromSeed([]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
|
||||
})
|
||||
)
|
||||
|
||||
type Room struct {
|
||||
ID string
|
||||
Version gomatrixserverlib.RoomVersion
|
||||
preset Preset
|
||||
creator *User
|
||||
|
||||
authEvents gomatrixserverlib.AuthEvents
|
||||
events []*gomatrixserverlib.HeaderedEvent
|
||||
}
|
||||
|
||||
// Create a new test room. Automatically creates the initial create events.
|
||||
func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
|
||||
t.Helper()
|
||||
counter := atomic.AddInt64(&roomIDCounter, 1)
|
||||
|
||||
// set defaults then let roomModifiers override
|
||||
r := &Room{
|
||||
ID: fmt.Sprintf("!%d:localhost", counter),
|
||||
creator: creator,
|
||||
authEvents: gomatrixserverlib.NewAuthEvents(nil),
|
||||
preset: PresetPublicChat,
|
||||
Version: gomatrixserverlib.RoomVersionV9,
|
||||
}
|
||||
for _, m := range modifiers {
|
||||
m(t, r)
|
||||
}
|
||||
r.insertCreateEvents(t)
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Room) insertCreateEvents(t *testing.T) {
|
||||
t.Helper()
|
||||
var joinRule gomatrixserverlib.JoinRuleContent
|
||||
var hisVis gomatrixserverlib.HistoryVisibilityContent
|
||||
plContent := eventutil.InitialPowerLevelsContent(r.creator.ID)
|
||||
switch r.preset {
|
||||
case PresetTrustedPrivateChat:
|
||||
fallthrough
|
||||
case PresetPrivateChat:
|
||||
joinRule.JoinRule = "invite"
|
||||
hisVis.HistoryVisibility = "shared"
|
||||
case PresetPublicChat:
|
||||
joinRule.JoinRule = "public"
|
||||
hisVis.HistoryVisibility = "shared"
|
||||
}
|
||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
|
||||
"creator": r.creator.ID,
|
||||
"room_version": r.Version,
|
||||
}, WithStateKey(""))
|
||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, WithStateKey(r.creator.ID))
|
||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
|
||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
|
||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
|
||||
}
|
||||
|
||||
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
|
||||
func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
|
||||
t.Helper()
|
||||
depth := 1 + len(r.events) // depth starts at 1
|
||||
|
||||
// possible event modifiers (optional fields)
|
||||
mod := &eventMods{}
|
||||
for _, m := range mods {
|
||||
m(mod)
|
||||
}
|
||||
|
||||
if mod.privKey == nil {
|
||||
mod.privKey = testPrivateKey
|
||||
}
|
||||
if mod.keyID == "" {
|
||||
mod.keyID = testKeyID
|
||||
}
|
||||
if mod.originServerTS.IsZero() {
|
||||
mod.originServerTS = time.Now()
|
||||
}
|
||||
if mod.origin == "" {
|
||||
mod.origin = gomatrixserverlib.ServerName("localhost")
|
||||
}
|
||||
|
||||
var unsigned gomatrixserverlib.RawJSON
|
||||
var err error
|
||||
if mod.unsigned != nil {
|
||||
unsigned, err = json.Marshal(mod.unsigned)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateEvent[%s]: failed to marshal unsigned field: %s", eventType, err)
|
||||
}
|
||||
}
|
||||
|
||||
builder := &gomatrixserverlib.EventBuilder{
|
||||
Sender: creator.ID,
|
||||
RoomID: r.ID,
|
||||
Type: eventType,
|
||||
StateKey: mod.stateKey,
|
||||
Depth: int64(depth),
|
||||
Unsigned: unsigned,
|
||||
}
|
||||
err = builder.SetContent(content)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateEvent[%s]: failed to SetContent: %s", eventType, err)
|
||||
}
|
||||
if depth > 1 {
|
||||
builder.PrevEvents = []gomatrixserverlib.EventReference{r.events[len(r.events)-1].EventReference()}
|
||||
}
|
||||
|
||||
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateEvent[%s]: failed to StateNeededForEventBuilder: %s", eventType, err)
|
||||
}
|
||||
refs, err := eventsNeeded.AuthEventReferences(&r.authEvents)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateEvent[%s]: failed to AuthEventReferences: %s", eventType, err)
|
||||
}
|
||||
builder.AuthEvents = refs
|
||||
ev, err := builder.Build(
|
||||
mod.originServerTS, mod.origin, mod.keyID,
|
||||
mod.privKey, r.Version,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err)
|
||||
}
|
||||
if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil {
|
||||
t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err)
|
||||
}
|
||||
return ev.Headered(r.Version)
|
||||
}
|
||||
|
||||
// Add a new event to this room DAG. Not thread-safe.
|
||||
func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) {
|
||||
t.Helper()
|
||||
// Add the event to the list of auth events
|
||||
r.events = append(r.events, he)
|
||||
if he.StateKey() != nil {
|
||||
err := r.authEvents.AddEvent(he.Unwrap())
|
||||
if err != nil {
|
||||
t.Fatalf("InsertEvent: failed to add event to auth events: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Room) Events() []*gomatrixserverlib.HeaderedEvent {
|
||||
return r.events
|
||||
}
|
||||
|
||||
func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
|
||||
t.Helper()
|
||||
he := r.CreateEvent(t, creator, eventType, content, mods...)
|
||||
r.InsertEvent(t, he)
|
||||
return he
|
||||
}
|
||||
|
||||
// All room modifiers are below
|
||||
|
||||
type roomModifier func(t *testing.T, r *Room)
|
||||
|
||||
func RoomPreset(p Preset) roomModifier {
|
||||
return func(t *testing.T, r *Room) {
|
||||
switch p {
|
||||
case PresetPrivateChat:
|
||||
fallthrough
|
||||
case PresetPublicChat:
|
||||
fallthrough
|
||||
case PresetTrustedPrivateChat:
|
||||
fallthrough
|
||||
case PresetNone:
|
||||
r.preset = p
|
||||
default:
|
||||
t.Errorf("invalid RoomPreset: %v", p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
|
||||
return func(t *testing.T, r *Room) {
|
||||
r.Version = ver
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,4 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
|
|
@ -13,24 +12,25 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
package test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type statements struct {
|
||||
media mediaStatements
|
||||
thumbnail thumbnailStatements
|
||||
var (
|
||||
userIDCounter = int64(0)
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID string
|
||||
}
|
||||
|
||||
func (s *statements) prepare(db *sql.DB) (err error) {
|
||||
if err = s.media.prepare(db); err != nil {
|
||||
return
|
||||
func NewUser() *User {
|
||||
counter := atomic.AddInt64(&userIDCounter, 1)
|
||||
u := &User{
|
||||
ID: fmt.Sprintf("@%d:localhost", counter),
|
||||
}
|
||||
if err = s.thumbnail.prepare(db); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
return u
|
||||
}
|
||||
|
|
@ -492,16 +492,16 @@ type PerformPusherDeletionRequest struct {
|
|||
|
||||
// Pusher represents a push notification subscriber
|
||||
type Pusher struct {
|
||||
SessionID int64 `json:"session_id,omitempty"`
|
||||
PushKey string `json:"pushkey"`
|
||||
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
|
||||
Kind PusherKind `json:"kind"`
|
||||
AppID string `json:"app_id"`
|
||||
AppDisplayName string `json:"app_display_name"`
|
||||
DeviceDisplayName string `json:"device_display_name"`
|
||||
ProfileTag string `json:"profile_tag"`
|
||||
Language string `json:"lang"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
SessionID int64 `json:"session_id,omitempty"`
|
||||
PushKey string `json:"pushkey"`
|
||||
PushKeyTS int64 `json:"pushkey_ts,omitempty"`
|
||||
Kind PusherKind `json:"kind"`
|
||||
AppID string `json:"app_id"`
|
||||
AppDisplayName string `json:"app_display_name"`
|
||||
DeviceDisplayName string `json:"device_display_name"`
|
||||
ProfileTag string `json:"profile_tag"`
|
||||
Language string `json:"lang"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
|
||||
type PusherKind string
|
||||
|
|
|
|||
|
|
@ -653,7 +653,7 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
|
|||
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
|
||||
}
|
||||
if req.Pusher.PushKeyTS == 0 {
|
||||
req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now())
|
||||
req.Pusher.PushKeyTS = int64(time.Now().Unix())
|
||||
}
|
||||
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -369,6 +369,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
|||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(QueryAccountAvailabilityPath,
|
||||
httputil.MakeInternalAPI("queryAccountAvailability", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryAccountAvailabilityRequest{}
|
||||
response := api.QueryAccountAvailabilityResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := s.QueryAccountAvailability(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(QueryAccountByPasswordPath,
|
||||
httputil.MakeInternalAPI("queryAccountByPassword", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryAccountByPasswordRequest{}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
|
@ -95,7 +94,7 @@ type pushersStatements struct {
|
|||
// Returns nil error success.
|
||||
func (s *pushersStatements) InsertPusher(
|
||||
ctx context.Context, txn *sql.Tx, session_id int64,
|
||||
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
logrus.Debugf("Created pusher %d", session_id)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
|
@ -95,7 +94,7 @@ type pushersStatements struct {
|
|||
// Returns nil error success.
|
||||
func (s *pushersStatements) InsertPusher(
|
||||
ctx context.Context, txn *sql.Tx, session_id int64,
|
||||
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
) error {
|
||||
_, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
logrus.Debugf("Created pusher %d", session_id)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type AccountDataTable interface {
|
||||
|
|
@ -96,7 +95,7 @@ type ThreePIDTable interface {
|
|||
}
|
||||
|
||||
type PusherTable interface {
|
||||
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
|
||||
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
|
||||
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
|
||||
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
|
||||
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
|
||||
|
|
|
|||
Loading…
Reference in a new issue