Merge branch 'main' into neilalexander/removelibp2p

This commit is contained in:
Neil Alexander 2022-04-21 11:22:46 +01:00
commit 786b89dcd5
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
68 changed files with 1397 additions and 977 deletions

View file

@ -52,6 +52,7 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeMulticast "github.com/matrix-org/pinecone/multicast"
pineconeRouter "github.com/matrix-org/pinecone/router" pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeSessions "github.com/matrix-org/pinecone/sessions" pineconeSessions "github.com/matrix-org/pinecone/sessions"
@ -71,11 +72,9 @@ type DendriteMonolith struct {
PineconeRouter *pineconeRouter.Router PineconeRouter *pineconeRouter.Router
PineconeMulticast *pineconeMulticast.Multicast PineconeMulticast *pineconeMulticast.Multicast
PineconeQUIC *pineconeSessions.Sessions PineconeQUIC *pineconeSessions.Sessions
PineconeManager *pineconeConnections.ConnectionManager
StorageDirectory string StorageDirectory string
CacheDirectory string CacheDirectory string
staticPeerURI string
staticPeerMutex sync.RWMutex
staticPeerAttempt chan struct{}
listener net.Listener listener net.Listener
httpServer *http.Server httpServer *http.Server
processContext *process.ProcessContext processContext *process.ProcessContext
@ -104,15 +103,8 @@ func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) {
} }
func (m *DendriteMonolith) SetStaticPeer(uri string) { func (m *DendriteMonolith) SetStaticPeer(uri string) {
m.staticPeerMutex.Lock() m.PineconeManager.RemovePeers()
m.staticPeerURI = strings.TrimSpace(uri) m.PineconeManager.AddPeer(strings.TrimSpace(uri))
m.staticPeerMutex.Unlock()
m.DisconnectType(int(pineconeRouter.PeerTypeRemote))
if uri != "" {
go func() {
m.staticPeerAttempt <- struct{}{}
}()
}
} }
func (m *DendriteMonolith) DisconnectType(peertype int) { func (m *DendriteMonolith) DisconnectType(peertype int) {
@ -210,43 +202,6 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e
return loginRes.Device.AccessToken, nil return loginRes.Device.AccessToken, nil
} }
func (m *DendriteMonolith) staticPeerConnect() {
connected := map[string]bool{} // URI -> connected?
attempt := func() {
m.staticPeerMutex.RLock()
uri := m.staticPeerURI
m.staticPeerMutex.RUnlock()
if uri == "" {
return
}
for k := range connected {
delete(connected, k)
}
for _, uri := range strings.Split(uri, ",") {
connected[strings.TrimSpace(uri)] = false
}
for _, info := range m.PineconeRouter.Peers() {
connected[info.URI] = true
}
for k, online := range connected {
if !online {
if err := conn.ConnectToPeer(m.PineconeRouter, k); err != nil {
logrus.WithError(err).Error("Failed to connect to static peer")
}
}
}
}
for {
select {
case <-m.processContext.Context().Done():
case <-m.staticPeerAttempt:
attempt()
case <-time.After(time.Second * 5):
attempt()
}
}
}
// nolint:gocyclo // nolint:gocyclo
func (m *DendriteMonolith) Start() { func (m *DendriteMonolith) Start() {
var err error var err error
@ -284,6 +239,7 @@ func (m *DendriteMonolith) Start() {
m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false) m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"}) m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"})
m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter) m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter)
m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter)
prefix := hex.EncodeToString(pk) prefix := hex.EncodeToString(pk)
cfg := &config.Dendrite{} cfg := &config.Dendrite{}
@ -392,9 +348,6 @@ func (m *DendriteMonolith) Start() {
m.processContext = base.ProcessContext m.processContext = base.ProcessContext
m.staticPeerAttempt = make(chan struct{}, 1)
go m.staticPeerConnect()
go func() { go func() {
m.logger.Info("Listening on ", cfg.Global.ServerName) m.logger.Info("Listening on ", cfg.Global.ServerName)
m.logger.Fatal(m.httpServer.Serve(m.PineconeQUIC.Protocol("matrix"))) m.logger.Fatal(m.httpServer.Serve(m.PineconeQUIC.Protocol("matrix")))

View file

@ -52,6 +52,7 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client
if turnConfig.SharedSecret != "" { if turnConfig.SharedSecret != "" {
expiry := time.Now().Add(duration).Unix() expiry := time.Now().Add(duration).Unix()
resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID)
mac := hmac.New(sha1.New, []byte(turnConfig.SharedSecret)) mac := hmac.New(sha1.New, []byte(turnConfig.SharedSecret))
_, err := mac.Write([]byte(resp.Username)) _, err := mac.Write([]byte(resp.Username))
@ -60,7 +61,6 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID)
resp.Password = base64.StdEncoding.EncodeToString(mac.Sum(nil)) resp.Password = base64.StdEncoding.EncodeToString(mac.Sum(nil))
} else if turnConfig.Username != "" && turnConfig.Password != "" { } else if turnConfig.Username != "" && turnConfig.Password != "" {
resp.Username = turnConfig.Username resp.Username = turnConfig.Username

View file

@ -25,7 +25,6 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"strings"
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -47,6 +46,7 @@ import (
"github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeMulticast "github.com/matrix-org/pinecone/multicast"
pineconeRouter "github.com/matrix-org/pinecone/router" pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeSessions "github.com/matrix-org/pinecone/sessions" pineconeSessions "github.com/matrix-org/pinecone/sessions"
@ -90,6 +90,13 @@ func main() {
} }
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false) pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
pManager := pineconeConnections.NewConnectionManager(pRouter)
pMulticast.Start()
if instancePeer != nil && *instancePeer != "" {
pManager.AddPeer(*instancePeer)
}
go func() { go func() {
listener, err := net.Listen("tcp", *instanceListen) listener, err := net.Listen("tcp", *instanceListen)
@ -119,36 +126,6 @@ func main() {
} }
}() }()
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
pMulticast.Start()
connectToStaticPeer := func() {
connected := map[string]bool{} // URI -> connected?
for _, uri := range strings.Split(*instancePeer, ",") {
connected[strings.TrimSpace(uri)] = false
}
attempt := func() {
for k := range connected {
connected[k] = false
}
for _, info := range pRouter.Peers() {
connected[info.URI] = true
}
for k, online := range connected {
if !online {
if err := conn.ConnectToPeer(pRouter, k); err != nil {
logrus.WithError(err).Error("Failed to connect to static peer")
}
}
}
}
for {
attempt()
time.Sleep(time.Second * 5)
}
}
cfg := &config.Dendrite{} cfg := &config.Dendrite{}
cfg.Defaults(true) cfg.Defaults(true)
cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))
@ -268,7 +245,6 @@ func main() {
Handler: pMux, Handler: pMux,
} }
go connectToStaticPeer()
go func() { go func() {
pubkey := pRouter.PublicKey() pubkey := pRouter.PublicKey()
logrus.Info("Listening on ", hex.EncodeToString(pubkey[:])) logrus.Info("Listening on ", hex.EncodeToString(pubkey[:]))

View file

@ -22,7 +22,6 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"syscall/js" "syscall/js"
"time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/appservice"
@ -44,6 +43,7 @@ import (
_ "github.com/matrix-org/go-sqlite3-js" _ "github.com/matrix-org/go-sqlite3-js"
pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeRouter "github.com/matrix-org/pinecone/router" pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeSessions "github.com/matrix-org/pinecone/sessions" pineconeSessions "github.com/matrix-org/pinecone/sessions"
) )
@ -154,6 +154,8 @@ func startup() {
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false) pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
pSessions := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) pSessions := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
pManager := pineconeConnections.NewConnectionManager(pRouter)
pManager.AddPeer("wss://pinecone.matrix.org/public")
cfg := &config.Dendrite{} cfg := &config.Dendrite{}
cfg.Defaults(true) cfg.Defaults(true)
@ -237,20 +239,4 @@ func startup() {
} }
s.ListenAndServe("fetch") s.ListenAndServe("fetch")
}() }()
// Connect to the static peer
go func() {
for {
if pRouter.PeerCount(pineconeRouter.PeerTypeRemote) == 0 {
if err := conn.ConnectToPeer(pRouter, publicPeer); err != nil {
logrus.WithError(err).Error("Failed to connect to static peer")
}
}
select {
case <-base.ProcessContext.Context().Done():
return
case <-time.After(time.Second * 5):
}
}
}()
} }

13
go.mod
View file

@ -1,8 +1,8 @@
module github.com/matrix-org/dendrite module github.com/matrix-org/dendrite
replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9
replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e
require ( require (
github.com/Arceliar/ironwood v0.0.0-20211125050254-8951369625d0 github.com/Arceliar/ironwood v0.0.0-20211125050254-8951369625d0
@ -27,15 +27,15 @@ require (
github.com/juju/testing v0.0.0-20220203020004-a0ff61f03494 // indirect github.com/juju/testing v0.0.0-20220203020004-a0ff61f03494 // indirect
github.com/lib/pq v1.10.5 github.com/lib/pq v1.10.5
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5 github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f
github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.10 github.com/mattn/go-sqlite3 v1.14.10
github.com/miekg/dns v1.1.31 // indirect github.com/miekg/dns v1.1.31 // indirect
github.com/nats-io/nats-server/v2 v2.7.4-0.20220309205833-773636c1c5bb github.com/nats-io/nats-server/v2 v2.7.4-0.20220309205833-773636c1c5bb
github.com/nats-io/nats.go v1.13.1-0.20220308171302-2f2f6968e98d github.com/nats-io/nats.go v1.14.0
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31 github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31
@ -56,6 +56,7 @@ require (
golang.org/x/image v0.0.0-20220321031419-a8550c1d254a golang.org/x/image v0.0.0-20220321031419-a8550c1d254a
golang.org/x/mobile v0.0.0-20220407111146-e579adbbc4a2 golang.org/x/mobile v0.0.0-20220407111146-e579adbbc4a2
golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3 golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0

27
go.sum
View file

@ -789,15 +789,15 @@ github.com/masterzen/winrm v0.0.0-20161014151040-7a535cd943fc/go.mod h1:CfZSN7zw
github.com/masterzen/xmlpath v0.0.0-20140218185901-13f4951698ad/go.mod h1:A0zPC53iKKKcXYxr4ROjpQRQ5FgJXtelNdSmHHuq/tY= github.com/masterzen/xmlpath v0.0.0-20140218185901-13f4951698ad/go.mod h1:A0zPC53iKKKcXYxr4ROjpQRQ5FgJXtelNdSmHHuq/tY=
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e h1:DP5RC0Z3XdyBEW5dKt8YPeN6vZbm6OzVaGVp7f1BQRM= github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e h1:DP5RC0Z3XdyBEW5dKt8YPeN6vZbm6OzVaGVp7f1BQRM=
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg= github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg=
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d h1:mGhPVaTht5NViFN/UpdrIlRApmH2FWcVaKUH5MdBKiY= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw2QV3YD/fRrzEDPNGgTlJlvXY0EHHnT87wF3OA=
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5 h1:Fkennny7+Z/5pygrhjFMZbz1j++P2hhhLoT7NO3p8DQ= github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f h1:MZrl4TgTnlaOn2Cu9gJCoJ3oyW5mT4/3QIZGgZXzKl4=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48= github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48=
github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d h1:1+T4eOPRsf6cr0lMPW4oO2k8TTHm4mqIh65kpEID5Rk= github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 h1:W0sjjC6yjskHX4mb0nk3p0fXAlbU5bAFUFeEtlrPASE=
github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc= github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
@ -878,8 +878,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296 h1:vU9tpM3apjYlLLeY23zRWJ9Zktr5jp+mloR942LEOpY= github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a h1:lem6QCvxR0Y28gth9P+wV2K/zYUUAkJ+55U8cpS0p5I=
github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8=
github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
@ -888,10 +888,10 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY
github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM=
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e h1:5tEHLzvDeS6IeqO2o9FFhsE3V2erYj8FlMt2J91wzsk= github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9 h1:VGU5HYAwy8LRbSkrT+kCHvujVmwK8Aa/vc1O+eReTbM=
github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e/go.mod h1:1vZ2Nijh8tcyNe8BDVyTviCd9NYzRbubQYiEHsvOQWc= github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9/go.mod h1:5vic7C58BFEVltiZhs7Kq81q2WcEPhJPsmNv1FOrdv0=
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q= github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e h1:kNIzIzj2OvnlreA+sTJ12nWJzTP3OSLNKDL/Iq9mF6Y=
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 h1:lrVQzBtkeQEGGYUHwSX1XPe1E5GL6U3KYCNe2G4bncQ= github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 h1:lrVQzBtkeQEGGYUHwSX1XPe1E5GL6U3KYCNe2G4bncQ=
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
@ -1541,8 +1541,9 @@ golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64 h1:D1v9ucDTYBtbz5vNuBbAhIMAGhQhJ6Ym5ah3maMVNX4=
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 h1:QyVthZKMsyaQwBTJE04jdNN0Pp5Fn9Qga0mrgxyERQM=
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=

View file

@ -0,0 +1,86 @@
package caching
import (
"fmt"
"time"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
const (
LazyLoadCacheName = "lazy_load_members"
LazyLoadCacheMaxEntries = 128
LazyLoadCacheMaxUserEntries = 128
LazyLoadCacheMutable = true
LazyLoadCacheMaxAge = time.Minute * 30
)
type LazyLoadCache struct {
// InMemoryLRUCachePartition containing other InMemoryLRUCachePartitions
// with the actual cached members
userCaches *InMemoryLRUCachePartition
}
// NewLazyLoadCache creates a new LazyLoadCache.
func NewLazyLoadCache() (*LazyLoadCache, error) {
cache, err := NewInMemoryLRUCachePartition(
LazyLoadCacheName,
LazyLoadCacheMutable,
LazyLoadCacheMaxEntries,
LazyLoadCacheMaxAge,
true,
)
if err != nil {
return nil, err
}
go cacheCleaner(cache)
return &LazyLoadCache{
userCaches: cache,
}, nil
}
func (c *LazyLoadCache) lazyLoadCacheForUser(device *userapi.Device) (*InMemoryLRUCachePartition, error) {
cacheName := fmt.Sprintf("%s/%s", device.UserID, device.ID)
userCache, ok := c.userCaches.Get(cacheName)
if ok && userCache != nil {
if cache, ok := userCache.(*InMemoryLRUCachePartition); ok {
return cache, nil
}
}
cache, err := NewInMemoryLRUCachePartition(
LazyLoadCacheName,
LazyLoadCacheMutable,
LazyLoadCacheMaxUserEntries,
LazyLoadCacheMaxAge,
false,
)
if err != nil {
return nil, err
}
c.userCaches.Set(cacheName, cache)
go cacheCleaner(cache)
return cache, nil
}
func (c *LazyLoadCache) StoreLazyLoadedUser(device *userapi.Device, roomID, userID, eventID string) {
cache, err := c.lazyLoadCacheForUser(device)
if err != nil {
return
}
cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID)
cache.Set(cacheKey, eventID)
}
func (c *LazyLoadCache) IsLazyLoadedUserCached(device *userapi.Device, roomID, userID string) (string, bool) {
cache, err := c.lazyLoadCacheForUser(device)
if err != nil {
return "", false
}
cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID)
val, ok := cache.Get(cacheKey)
if !ok {
return "", ok
}
return val.(string), ok
}

View file

@ -3,8 +3,6 @@ package pushgateway
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/matrix-org/gomatrixserverlib"
) )
// A Client is how interactions with a Push Gateway is done. // A Client is how interactions with a Push Gateway is done.
@ -47,11 +45,11 @@ type Counts struct {
} }
type Device struct { type Device struct {
AppID string `json:"app_id"` // Required AppID string `json:"app_id"` // Required
Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys. Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys.
PushKey string `json:"pushkey"` // Required PushKey string `json:"pushkey"` // Required
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"` PushKeyTS int64 `json:"pushkey_ts,omitempty"`
Tweaks map[string]interface{} `json:"tweaks,omitempty"` Tweaks map[string]interface{} `json:"tweaks,omitempty"`
} }
type Prio string type Prio string

View file

@ -32,7 +32,7 @@ func AddPublicRoutes(
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
client *gomatrixserverlib.Client, client *gomatrixserverlib.Client,
) { ) {
mediaDB, err := storage.Open(&cfg.Database) mediaDB, err := storage.NewMediaAPIDatasource(&cfg.Database)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to media db") logrus.WithError(err).Panicf("failed to connect to media db")
} }

View file

@ -22,6 +22,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"os"
"path" "path"
"strings" "strings"
@ -311,6 +312,26 @@ func (r *uploadRequest) storeFileAndMetadata(
} }
go func() { go func() {
file, err := os.Open(string(finalPath))
if err != nil {
r.Logger.WithError(err).Error("unable to open file")
return
}
defer file.Close() // nolint: errcheck
// http.DetectContentType only needs 512 bytes
buf := make([]byte, 512)
_, err = file.Read(buf)
if err != nil {
r.Logger.WithError(err).Error("unable to read file")
return
}
// Check if we need to generate thumbnails
fileType := http.DetectContentType(buf)
if !strings.HasPrefix(fileType, "image") {
r.Logger.WithField("contentType", fileType).Debugf("uploaded file is not an image or can not be thumbnailed, not generating thumbnails")
return
}
busy, err := thumbnailer.GenerateThumbnails( busy, err := thumbnailer.GenerateThumbnails(
context.Background(), finalPath, thumbnailSizes, r.MediaMetadata, context.Background(), finalPath, thumbnailSizes, r.MediaMetadata,
activeThumbnailGeneration, maxThumbnailGenerators, db, r.Logger, activeThumbnailGeneration, maxThumbnailGenerators, db, r.Logger,

View file

@ -51,7 +51,7 @@ func Test_uploadRequest_doUpload(t *testing.T) {
_ = os.Mkdir(testdataPath, os.ModePerm) _ = os.Mkdir(testdataPath, os.ModePerm)
defer fileutils.RemoveDir(types.Path(testdataPath), nil) defer fileutils.RemoveDir(types.Path(testdataPath), nil)
db, err := storage.Open(&config.DatabaseOptions{ db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{
ConnectionString: "file::memory:?cache=shared", ConnectionString: "file::memory:?cache=shared",
MaxOpenConnections: 100, MaxOpenConnections: 100,
MaxIdleConnections: 2, MaxIdleConnections: 2,

View file

@ -22,9 +22,17 @@ import (
) )
type Database interface { type Database interface {
MediaRepository
Thumbnails
}
type MediaRepository interface {
StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error
GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
}
type Thumbnails interface {
StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error
GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error)
GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error)

View file

@ -20,6 +20,8 @@ import (
"database/sql" "database/sql"
"time" "time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -69,24 +71,25 @@ type mediaStatements struct {
selectMediaByHashStmt *sql.Stmt selectMediaByHashStmt *sql.Stmt
} }
func (s *mediaStatements) prepare(db *sql.DB) (err error) { func NewPostgresMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) {
_, err = db.Exec(mediaSchema) s := &mediaStatements{}
_, err := db.Exec(mediaSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, sqlutil.StatementList{
{&s.insertMediaStmt, insertMediaSQL}, {&s.insertMediaStmt, insertMediaSQL},
{&s.selectMediaStmt, selectMediaSQL}, {&s.selectMediaStmt, selectMediaSQL},
{&s.selectMediaByHashStmt, selectMediaByHashSQL}, {&s.selectMediaByHashStmt, selectMediaByHashSQL},
}.prepare(db) }.Prepare(db)
} }
func (s *mediaStatements) insertMedia( func (s *mediaStatements) InsertMedia(
ctx context.Context, mediaMetadata *types.MediaMetadata, ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata,
) error { ) error {
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
_, err := s.insertMediaStmt.ExecContext( _, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext(
ctx, ctx,
mediaMetadata.MediaID, mediaMetadata.MediaID,
mediaMetadata.Origin, mediaMetadata.Origin,
@ -100,14 +103,14 @@ func (s *mediaStatements) insertMedia(
return err return err
} }
func (s *mediaStatements) selectMedia( func (s *mediaStatements) SelectMedia(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) { ) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{ mediaMetadata := types.MediaMetadata{
MediaID: mediaID, MediaID: mediaID,
Origin: mediaOrigin, Origin: mediaOrigin,
} }
err := s.selectMediaStmt.QueryRowContext( err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext(
ctx, mediaMetadata.MediaID, mediaMetadata.Origin, ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
).Scan( ).Scan(
&mediaMetadata.ContentType, &mediaMetadata.ContentType,
@ -120,14 +123,14 @@ func (s *mediaStatements) selectMedia(
return &mediaMetadata, err return &mediaMetadata, err
} }
func (s *mediaStatements) selectMediaByHash( func (s *mediaStatements) SelectMediaByHash(
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) { ) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{ mediaMetadata := types.MediaMetadata{
Base64Hash: mediaHash, Base64Hash: mediaHash,
Origin: mediaOrigin, Origin: mediaOrigin,
} }
err := s.selectMediaStmt.QueryRowContext( err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext(
ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin, ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
).Scan( ).Scan(
&mediaMetadata.ContentType, &mediaMetadata.ContentType,

View file

@ -0,0 +1,46 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/shared"
"github.com/matrix-org/dendrite/setup/config"
)
// NewDatabase opens a postgres database.
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
mediaRepo, err := NewPostgresMediaRepositoryTable(db)
if err != nil {
return nil, err
}
thumbnails, err := NewPostgresThumbnailsTable(db)
if err != nil {
return nil, err
}
return &shared.Database{
MediaRepository: mediaRepo,
Thumbnails: thumbnails,
DB: db,
Writer: sqlutil.NewExclusiveWriter(),
}, nil
}

View file

@ -1,38 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// FIXME: This should be made internal!
package postgres
import (
"database/sql"
)
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
type statementList []struct {
statement **sql.Stmt
sql string
}
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
func (s statementList) prepare(db *sql.DB) (err error) {
for _, statement := range s {
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
return
}
}
return
}

View file

@ -1,36 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"database/sql"
)
type statements struct {
media mediaStatements
thumbnail thumbnailStatements
}
func (s *statements) prepare(db *sql.DB) (err error) {
if err = s.media.prepare(db); err != nil {
return
}
if err = s.thumbnail.prepare(db); err != nil {
return
}
return
}

View file

@ -21,6 +21,8 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -63,7 +65,7 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE
// Note: this selects all thumbnails for a media_origin and media_id // Note: this selects all thumbnails for a media_origin and media_id
const selectThumbnailsSQL = ` const selectThumbnailsSQL = `
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC
` `
type thumbnailStatements struct { type thumbnailStatements struct {
@ -72,24 +74,25 @@ type thumbnailStatements struct {
selectThumbnailsStmt *sql.Stmt selectThumbnailsStmt *sql.Stmt
} }
func (s *thumbnailStatements) prepare(db *sql.DB) (err error) { func NewPostgresThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) {
_, err = db.Exec(thumbnailSchema) s := &thumbnailStatements{}
_, err := db.Exec(thumbnailSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, sqlutil.StatementList{
{&s.insertThumbnailStmt, insertThumbnailSQL}, {&s.insertThumbnailStmt, insertThumbnailSQL},
{&s.selectThumbnailStmt, selectThumbnailSQL}, {&s.selectThumbnailStmt, selectThumbnailSQL},
{&s.selectThumbnailsStmt, selectThumbnailsSQL}, {&s.selectThumbnailsStmt, selectThumbnailsSQL},
}.prepare(db) }.Prepare(db)
} }
func (s *thumbnailStatements) insertThumbnail( func (s *thumbnailStatements) InsertThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata,
) error { ) error {
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
_, err := s.insertThumbnailStmt.ExecContext( _, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext(
ctx, ctx,
thumbnailMetadata.MediaMetadata.MediaID, thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin, thumbnailMetadata.MediaMetadata.Origin,
@ -103,8 +106,9 @@ func (s *thumbnailStatements) insertThumbnail(
return err return err
} }
func (s *thumbnailStatements) selectThumbnail( func (s *thumbnailStatements) SelectThumbnail(
ctx context.Context, ctx context.Context,
txn *sql.Tx,
mediaID types.MediaID, mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName, mediaOrigin gomatrixserverlib.ServerName,
width, height int, width, height int,
@ -121,7 +125,7 @@ func (s *thumbnailStatements) selectThumbnail(
ResizeMethod: resizeMethod, ResizeMethod: resizeMethod,
}, },
} }
err := s.selectThumbnailStmt.QueryRowContext( err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext(
ctx, ctx,
thumbnailMetadata.MediaMetadata.MediaID, thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin, thumbnailMetadata.MediaMetadata.Origin,
@ -136,10 +140,10 @@ func (s *thumbnailStatements) selectThumbnail(
return &thumbnailMetadata, err return &thumbnailMetadata, err
} }
func (s *thumbnailStatements) selectThumbnails( func (s *thumbnailStatements) SelectThumbnails(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error) { ) ([]*types.ThumbnailMetadata, error) {
rows, err := s.selectThumbnailsStmt.QueryContext( rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext(
ctx, mediaID, mediaOrigin, ctx, mediaID, mediaOrigin,
) )
if err != nil { if err != nil {

View file

@ -1,5 +1,4 @@
// Copyright 2017-2018 New Vector Ltd // Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -13,54 +12,38 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package postgres package shared
import ( import (
"context" "context"
"database/sql" "database/sql"
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// Database is used to store metadata about a repository of media files.
type Database struct { type Database struct {
statements statements DB *sql.DB
db *sql.DB Writer sqlutil.Writer
} MediaRepository tables.MediaRepository
Thumbnails tables.Thumbnails
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
var d Database
var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
if err = d.statements.prepare(d.db); err != nil {
return nil, err
}
return &d, nil
} }
// StoreMediaMetadata inserts the metadata about the uploaded media into the database. // StoreMediaMetadata inserts the metadata about the uploaded media into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table. // Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreMediaMetadata( func (d Database) StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error {
ctx context.Context, mediaMetadata *types.MediaMetadata, return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
) error { return d.MediaRepository.InsertMedia(ctx, txn, mediaMetadata)
return d.statements.media.insertMedia(ctx, mediaMetadata) })
} }
// GetMediaMetadata returns metadata about media stored on this server. // GetMediaMetadata returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here. // The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media. // Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadata( func (d Database) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, mediaMetadata, err := d.MediaRepository.SelectMedia(ctx, nil, mediaID, mediaOrigin)
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows { if err != nil && err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
@ -70,10 +53,8 @@ func (d *Database) GetMediaMetadata(
// GetMediaMetadataByHash returns metadata about media stored on this server. // GetMediaMetadataByHash returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here. // The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media. // Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadataByHash( func (d Database) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, mediaMetadata, err := d.MediaRepository.SelectMediaByHash(ctx, nil, mediaHash, mediaOrigin)
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin)
if err != nil && err == sql.ErrNoRows { if err != nil && err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
@ -82,40 +63,36 @@ func (d *Database) GetMediaMetadataByHash(
// StoreThumbnail inserts the metadata about the thumbnail into the database. // StoreThumbnail inserts the metadata about the thumbnail into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table. // Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreThumbnail( func (d Database) StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error {
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
) error { return d.Thumbnails.InsertThumbnail(ctx, txn, thumbnailMetadata)
return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata) })
} }
// GetThumbnail returns metadata about a specific thumbnail. // GetThumbnail returns metadata about a specific thumbnail.
// The media could have been uploaded to this server or fetched from another server and cached here. // The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this thumbnail. // Returns nil metadata if there is no metadata associated with this thumbnail.
func (d *Database) GetThumbnail( func (d Database) GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) {
ctx context.Context, metadata, err := d.Thumbnails.SelectThumbnail(ctx, nil, mediaID, mediaOrigin, width, height, resizeMethod)
mediaID types.MediaID, if err != nil {
mediaOrigin gomatrixserverlib.ServerName, if err == sql.ErrNoRows {
width, height int, return nil, nil
resizeMethod string, }
) (*types.ThumbnailMetadata, error) { return nil, err
thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail(
ctx, mediaID, mediaOrigin, width, height, resizeMethod,
)
if err != nil && err == sql.ErrNoRows {
return nil, nil
} }
return thumbnailMetadata, err return metadata, err
} }
// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server. // GetThumbnails returns metadata about all thumbnails for a specific media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here. // The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there are no thumbnails associated with this media. // Returns nil metadata if there are no thumbnails associated with this media.
func (d *Database) GetThumbnails( func (d Database) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) {
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, metadatas, err := d.Thumbnails.SelectThumbnails(ctx, nil, mediaID, mediaOrigin)
) ([]*types.ThumbnailMetadata, error) { if err != nil {
thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin) if err == sql.ErrNoRows {
if err != nil && err == sql.ErrNoRows { return nil, nil
return nil, nil }
return nil, err
} }
return thumbnails, err return metadatas, err
} }

View file

@ -21,6 +21,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -66,57 +67,53 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, media_id, user_i
type mediaStatements struct { type mediaStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer
insertMediaStmt *sql.Stmt insertMediaStmt *sql.Stmt
selectMediaStmt *sql.Stmt selectMediaStmt *sql.Stmt
selectMediaByHashStmt *sql.Stmt selectMediaByHashStmt *sql.Stmt
} }
func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { func NewSQLiteMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) {
s.db = db s := &mediaStatements{
s.writer = writer db: db,
}
_, err = db.Exec(mediaSchema) _, err := db.Exec(mediaSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, sqlutil.StatementList{
{&s.insertMediaStmt, insertMediaSQL}, {&s.insertMediaStmt, insertMediaSQL},
{&s.selectMediaStmt, selectMediaSQL}, {&s.selectMediaStmt, selectMediaSQL},
{&s.selectMediaByHashStmt, selectMediaByHashSQL}, {&s.selectMediaByHashStmt, selectMediaByHashSQL},
}.prepare(db) }.Prepare(db)
} }
func (s *mediaStatements) insertMedia( func (s *mediaStatements) InsertMedia(
ctx context.Context, mediaMetadata *types.MediaMetadata, ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata,
) error { ) error {
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { _, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext(
stmt := sqlutil.TxStmt(txn, s.insertMediaStmt) ctx,
_, err := stmt.ExecContext( mediaMetadata.MediaID,
ctx, mediaMetadata.Origin,
mediaMetadata.MediaID, mediaMetadata.ContentType,
mediaMetadata.Origin, mediaMetadata.FileSizeBytes,
mediaMetadata.ContentType, mediaMetadata.CreationTimestamp,
mediaMetadata.FileSizeBytes, mediaMetadata.UploadName,
mediaMetadata.CreationTimestamp, mediaMetadata.Base64Hash,
mediaMetadata.UploadName, mediaMetadata.UserID,
mediaMetadata.Base64Hash, )
mediaMetadata.UserID, return err
)
return err
})
} }
func (s *mediaStatements) selectMedia( func (s *mediaStatements) SelectMedia(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) { ) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{ mediaMetadata := types.MediaMetadata{
MediaID: mediaID, MediaID: mediaID,
Origin: mediaOrigin, Origin: mediaOrigin,
} }
err := s.selectMediaStmt.QueryRowContext( err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext(
ctx, mediaMetadata.MediaID, mediaMetadata.Origin, ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
).Scan( ).Scan(
&mediaMetadata.ContentType, &mediaMetadata.ContentType,
@ -129,14 +126,14 @@ func (s *mediaStatements) selectMedia(
return &mediaMetadata, err return &mediaMetadata, err
} }
func (s *mediaStatements) selectMediaByHash( func (s *mediaStatements) SelectMediaByHash(
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) { ) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{ mediaMetadata := types.MediaMetadata{
Base64Hash: mediaHash, Base64Hash: mediaHash,
Origin: mediaOrigin, Origin: mediaOrigin,
} }
err := s.selectMediaStmt.QueryRowContext( err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext(
ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin, ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
).Scan( ).Scan(
&mediaMetadata.ContentType, &mediaMetadata.ContentType,

View file

@ -16,23 +16,30 @@
package sqlite3 package sqlite3
import ( import (
"database/sql" // Import the postgres database driver.
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/shared"
"github.com/matrix-org/dendrite/setup/config"
) )
type statements struct { // NewDatabase opens a SQLIte database.
media mediaStatements func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
thumbnail thumbnailStatements db, err := sqlutil.Open(dbProperties)
} if err != nil {
return nil, err
func (s *statements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
if err = s.media.prepare(db, writer); err != nil {
return
} }
if err = s.thumbnail.prepare(db, writer); err != nil { mediaRepo, err := NewSQLiteMediaRepositoryTable(db)
return if err != nil {
return nil, err
} }
thumbnails, err := NewSQLiteThumbnailsTable(db)
return if err != nil {
return nil, err
}
return &shared.Database{
MediaRepository: mediaRepo,
Thumbnails: thumbnails,
DB: db,
Writer: sqlutil.NewExclusiveWriter(),
}, nil
} }

View file

@ -1,38 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// FIXME: This should be made internal!
package sqlite3
import (
"database/sql"
)
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
type statementList []struct {
statement **sql.Stmt
sql string
}
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
func (s statementList) prepare(db *sql.DB) (err error) {
for _, statement := range s {
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
return
}
}
return
}

View file

@ -1,123 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
// Import the postgres database driver.
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database is used to store metadata about a repository of media files.
type Database struct {
statements statements
db *sql.DB
writer sqlutil.Writer
}
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
d := Database{
writer: sqlutil.NewExclusiveWriter(),
}
var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
if err = d.statements.prepare(d.db, d.writer); err != nil {
return nil, err
}
return &d, nil
}
// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreMediaMetadata(
ctx context.Context, mediaMetadata *types.MediaMetadata,
) error {
return d.statements.media.insertMedia(ctx, mediaMetadata)
}
// GetMediaMetadata returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadata(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return mediaMetadata, err
}
// GetMediaMetadataByHash returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadataByHash(
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return mediaMetadata, err
}
// StoreThumbnail inserts the metadata about the thumbnail into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
) error {
return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata)
}
// GetThumbnail returns metadata about a specific thumbnail.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this thumbnail.
func (d *Database) GetThumbnail(
ctx context.Context,
mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
width, height int,
resizeMethod string,
) (*types.ThumbnailMetadata, error) {
thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail(
ctx, mediaID, mediaOrigin, width, height, resizeMethod,
)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return thumbnailMetadata, err
}
// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there are no thumbnails associated with this media.
func (d *Database) GetThumbnails(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error) {
thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return thumbnails, err
}

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -54,55 +55,48 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE
// Note: this selects all thumbnails for a media_origin and media_id // Note: this selects all thumbnails for a media_origin and media_id
const selectThumbnailsSQL = ` const selectThumbnailsSQL = `
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC
` `
type thumbnailStatements struct { type thumbnailStatements struct {
db *sql.DB
writer sqlutil.Writer
insertThumbnailStmt *sql.Stmt insertThumbnailStmt *sql.Stmt
selectThumbnailStmt *sql.Stmt selectThumbnailStmt *sql.Stmt
selectThumbnailsStmt *sql.Stmt selectThumbnailsStmt *sql.Stmt
} }
func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { func NewSQLiteThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) {
_, err = db.Exec(thumbnailSchema) s := &thumbnailStatements{}
_, err := db.Exec(thumbnailSchema)
if err != nil { if err != nil {
return return nil, err
} }
s.db = db
s.writer = writer
return statementList{ return s, sqlutil.StatementList{
{&s.insertThumbnailStmt, insertThumbnailSQL}, {&s.insertThumbnailStmt, insertThumbnailSQL},
{&s.selectThumbnailStmt, selectThumbnailSQL}, {&s.selectThumbnailStmt, selectThumbnailSQL},
{&s.selectThumbnailsStmt, selectThumbnailsSQL}, {&s.selectThumbnailsStmt, selectThumbnailsSQL},
}.prepare(db) }.Prepare(db)
} }
func (s *thumbnailStatements) insertThumbnail( func (s *thumbnailStatements) InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error {
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
) error { _, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext(
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) ctx,
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { thumbnailMetadata.MediaMetadata.MediaID,
stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt) thumbnailMetadata.MediaMetadata.Origin,
_, err := stmt.ExecContext( thumbnailMetadata.MediaMetadata.ContentType,
ctx, thumbnailMetadata.MediaMetadata.FileSizeBytes,
thumbnailMetadata.MediaMetadata.MediaID, thumbnailMetadata.MediaMetadata.CreationTimestamp,
thumbnailMetadata.MediaMetadata.Origin, thumbnailMetadata.ThumbnailSize.Width,
thumbnailMetadata.MediaMetadata.ContentType, thumbnailMetadata.ThumbnailSize.Height,
thumbnailMetadata.MediaMetadata.FileSizeBytes, thumbnailMetadata.ThumbnailSize.ResizeMethod,
thumbnailMetadata.MediaMetadata.CreationTimestamp, )
thumbnailMetadata.ThumbnailSize.Width, return err
thumbnailMetadata.ThumbnailSize.Height,
thumbnailMetadata.ThumbnailSize.ResizeMethod,
)
return err
})
} }
func (s *thumbnailStatements) selectThumbnail( func (s *thumbnailStatements) SelectThumbnail(
ctx context.Context, ctx context.Context,
txn *sql.Tx,
mediaID types.MediaID, mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName, mediaOrigin gomatrixserverlib.ServerName,
width, height int, width, height int,
@ -119,7 +113,7 @@ func (s *thumbnailStatements) selectThumbnail(
ResizeMethod: resizeMethod, ResizeMethod: resizeMethod,
}, },
} }
err := s.selectThumbnailStmt.QueryRowContext( err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext(
ctx, ctx,
thumbnailMetadata.MediaMetadata.MediaID, thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin, thumbnailMetadata.MediaMetadata.Origin,
@ -134,10 +128,11 @@ func (s *thumbnailStatements) selectThumbnail(
return &thumbnailMetadata, err return &thumbnailMetadata, err
} }
func (s *thumbnailStatements) selectThumbnails( func (s *thumbnailStatements) SelectThumbnails(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error) { ) ([]*types.ThumbnailMetadata, error) {
rows, err := s.selectThumbnailsStmt.QueryContext( rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext(
ctx, mediaID, mediaOrigin, ctx, mediaID, mediaOrigin,
) )
if err != nil { if err != nil {

View file

@ -25,13 +25,13 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
) )
// Open opens a postgres database. // NewMediaAPIDatasource opens a database connection.
func Open(dbProperties *config.DatabaseOptions) (Database, error) { func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.Open(dbProperties) return sqlite3.NewDatabase(dbProperties)
case dbProperties.ConnectionString.IsPostgres(): case dbProperties.ConnectionString.IsPostgres():
return postgres.Open(dbProperties) return postgres.NewDatabase(dbProperties)
default: default:
return nil, fmt.Errorf("unexpected database type") return nil, fmt.Errorf("unexpected database type")
} }

View file

@ -0,0 +1,135 @@
package storage_test
import (
"context"
"reflect"
"testing"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("NewSyncServerDatasource returned %s", err)
}
return db, close
}
func TestMediaRepository(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
ctx := context.Background()
t.Run("can insert media & query media", func(t *testing.T) {
metadata := &types.MediaMetadata{
MediaID: "testing",
Origin: "localhost",
ContentType: "image/png",
FileSizeBytes: 10,
UploadName: "upload test",
Base64Hash: "dGVzdGluZw==",
UserID: "@alice:localhost",
}
if err := db.StoreMediaMetadata(ctx, metadata); err != nil {
t.Fatalf("unable to store media metadata: %v", err)
}
// query by media id
gotMetadata, err := db.GetMediaMetadata(ctx, metadata.MediaID, metadata.Origin)
if err != nil {
t.Fatalf("unable to query media metadata: %v", err)
}
if !reflect.DeepEqual(metadata, gotMetadata) {
t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata)
}
// query by media hash
gotMetadata, err = db.GetMediaMetadataByHash(ctx, metadata.Base64Hash, metadata.Origin)
if err != nil {
t.Fatalf("unable to query media metadata by hash: %v", err)
}
if !reflect.DeepEqual(metadata, gotMetadata) {
t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata)
}
})
})
}
func TestThumbnailsStorage(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
ctx := context.Background()
t.Run("can insert thumbnails & query media", func(t *testing.T) {
thumbnails := []*types.ThumbnailMetadata{
{
MediaMetadata: &types.MediaMetadata{
MediaID: "testing",
Origin: "localhost",
ContentType: "image/png",
FileSizeBytes: 6,
},
ThumbnailSize: types.ThumbnailSize{
Width: 5,
Height: 5,
ResizeMethod: types.Crop,
},
},
{
MediaMetadata: &types.MediaMetadata{
MediaID: "testing",
Origin: "localhost",
ContentType: "image/png",
FileSizeBytes: 7,
},
ThumbnailSize: types.ThumbnailSize{
Width: 1,
Height: 1,
ResizeMethod: types.Scale,
},
},
}
for i := range thumbnails {
if err := db.StoreThumbnail(ctx, thumbnails[i]); err != nil {
t.Fatalf("unable to store thumbnail metadata: %v", err)
}
}
// query by single thumbnail
gotMetadata, err := db.GetThumbnail(ctx,
thumbnails[0].MediaMetadata.MediaID,
thumbnails[0].MediaMetadata.Origin,
thumbnails[0].ThumbnailSize.Width, thumbnails[0].ThumbnailSize.Height,
thumbnails[0].ThumbnailSize.ResizeMethod,
)
if err != nil {
t.Fatalf("unable to query thumbnail metadata: %v", err)
}
if !reflect.DeepEqual(thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata) {
t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata)
}
if !reflect.DeepEqual(thumbnails[0].ThumbnailSize, gotMetadata.ThumbnailSize) {
t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata)
}
// query by all thumbnails
gotMediadatas, err := db.GetThumbnails(ctx, thumbnails[0].MediaMetadata.MediaID, thumbnails[0].MediaMetadata.Origin)
if err != nil {
t.Fatalf("unable to query media metadata by hash: %v", err)
}
if len(gotMediadatas) != len(thumbnails) {
t.Fatalf("expected %d stored thumbnail metadata, got %d", len(thumbnails), len(gotMediadatas))
}
for i := range gotMediadatas {
if !reflect.DeepEqual(thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata) {
t.Fatalf("expected metadata %+v, got %v", thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata)
}
if !reflect.DeepEqual(thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize) {
t.Fatalf("expected metadata %+v, got %v", thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize)
}
}
})
})
}

View file

@ -22,10 +22,10 @@ import (
) )
// Open opens a postgres database. // Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (Database, error) { func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.Open(dbProperties) return sqlite3.NewDatabase(dbProperties)
case dbProperties.ConnectionString.IsPostgres(): case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation") return nil, fmt.Errorf("can't use Postgres implementation")
default: default:

View file

@ -0,0 +1,46 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
type Thumbnails interface {
InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error
SelectThumbnail(
ctx context.Context, txn *sql.Tx,
mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
width, height int,
resizeMethod string,
) (*types.ThumbnailMetadata, error)
SelectThumbnails(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error)
}
type MediaRepository interface {
InsertMedia(ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata) error
SelectMedia(ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
SelectMediaByHash(
ctx context.Context, txn *sql.Tx,
mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error)
}

View file

@ -45,16 +45,13 @@ type RequestMethod string
// MatrixUserID is a Matrix user ID string in the form @user:domain e.g. @alice:matrix.org // MatrixUserID is a Matrix user ID string in the form @user:domain e.g. @alice:matrix.org
type MatrixUserID string type MatrixUserID string
// UnixMs is the milliseconds since the Unix epoch
type UnixMs int64
// MediaMetadata is metadata associated with a media file // MediaMetadata is metadata associated with a media file
type MediaMetadata struct { type MediaMetadata struct {
MediaID MediaID MediaID MediaID
Origin gomatrixserverlib.ServerName Origin gomatrixserverlib.ServerName
ContentType ContentType ContentType ContentType
FileSizeBytes FileSizeBytes FileSizeBytes FileSizeBytes
CreationTimestamp UnixMs CreationTimestamp gomatrixserverlib.Timestamp
UploadName Filename UploadName Filename
Base64Hash Base64Hash Base64Hash Base64Hash
UserID MatrixUserID UserID MatrixUserID

View file

@ -36,7 +36,7 @@ import (
type Notifier struct { type Notifier struct {
lock *sync.RWMutex lock *sync.RWMutex
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine // A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
roomIDToJoinedUsers map[string]userIDSet roomIDToJoinedUsers map[string]*userIDSet
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine // A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
roomIDToPeekingDevices map[string]peekingDeviceSet roomIDToPeekingDevices map[string]peekingDeviceSet
// The latest sync position // The latest sync position
@ -54,7 +54,7 @@ type Notifier struct {
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
func NewNotifier() *Notifier { func NewNotifier() *Notifier {
return &Notifier{ return &Notifier{
roomIDToJoinedUsers: make(map[string]userIDSet), roomIDToJoinedUsers: make(map[string]*userIDSet),
roomIDToPeekingDevices: make(map[string]peekingDeviceSet), roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
userDeviceStreams: make(map[string]map[string]*UserDeviceStream), userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
lock: &sync.RWMutex{}, lock: &sync.RWMutex{},
@ -262,7 +262,7 @@ func (n *Notifier) SharedUsers(userID string) []string {
func (n *Notifier) _sharedUsers(userID string) []string { func (n *Notifier) _sharedUsers(userID string) []string {
n._sharedUserMap[userID] = struct{}{} n._sharedUserMap[userID] = struct{}{}
for roomID, users := range n.roomIDToJoinedUsers { for roomID, users := range n.roomIDToJoinedUsers {
if _, ok := users[userID]; !ok { if ok := users.isIn(userID); !ok {
continue continue
} }
for _, userID := range n._joinedUsers(roomID) { for _, userID := range n._joinedUsers(roomID) {
@ -282,8 +282,11 @@ func (n *Notifier) IsSharedUser(userA, userB string) bool {
defer n.lock.RUnlock() defer n.lock.RUnlock()
var okA, okB bool var okA, okB bool
for _, users := range n.roomIDToJoinedUsers { for _, users := range n.roomIDToJoinedUsers {
_, okA = users[userA] okA = users.isIn(userA)
_, okB = users[userB] if !okA {
continue
}
okB = users.isIn(userB)
if okA && okB { if okA && okB {
return true return true
} }
@ -345,11 +348,12 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
// This is just the bulk form of addJoinedUser // This is just the bulk form of addJoinedUser
for roomID, userIDs := range roomIDToUserIDs { for roomID, userIDs := range roomIDToUserIDs {
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
n.roomIDToJoinedUsers[roomID] = make(userIDSet, len(userIDs)) n.roomIDToJoinedUsers[roomID] = newUserIDSet(len(userIDs))
} }
for _, userID := range userIDs { for _, userID := range userIDs {
n.roomIDToJoinedUsers[roomID].add(userID) n.roomIDToJoinedUsers[roomID].add(userID)
} }
n.roomIDToJoinedUsers[roomID].precompute()
} }
} }
@ -440,16 +444,18 @@ func (n *Notifier) _fetchUserStreams(userID string) []*UserDeviceStream {
func (n *Notifier) _addJoinedUser(roomID, userID string) { func (n *Notifier) _addJoinedUser(roomID, userID string) {
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
n.roomIDToJoinedUsers[roomID] = make(userIDSet) n.roomIDToJoinedUsers[roomID] = newUserIDSet(8)
} }
n.roomIDToJoinedUsers[roomID].add(userID) n.roomIDToJoinedUsers[roomID].add(userID)
n.roomIDToJoinedUsers[roomID].precompute()
} }
func (n *Notifier) _removeJoinedUser(roomID, userID string) { func (n *Notifier) _removeJoinedUser(roomID, userID string) {
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
n.roomIDToJoinedUsers[roomID] = make(userIDSet) n.roomIDToJoinedUsers[roomID] = newUserIDSet(8)
} }
n.roomIDToJoinedUsers[roomID].remove(userID) n.roomIDToJoinedUsers[roomID].remove(userID)
n.roomIDToJoinedUsers[roomID].precompute()
} }
func (n *Notifier) JoinedUsers(roomID string) (userIDs []string) { func (n *Notifier) JoinedUsers(roomID string) (userIDs []string) {
@ -521,19 +527,52 @@ func (n *Notifier) _removeEmptyUserStreams() {
} }
// A string set, mainly existing for improving clarity of structs in this file. // A string set, mainly existing for improving clarity of structs in this file.
type userIDSet map[string]struct{} type userIDSet struct {
sync.Mutex
func (s userIDSet) add(str string) { set map[string]struct{}
s[str] = struct{}{} precomputed []string
} }
func (s userIDSet) remove(str string) { func newUserIDSet(cap int) *userIDSet {
delete(s, str) return &userIDSet{
set: make(map[string]struct{}, cap),
precomputed: nil,
}
} }
func (s userIDSet) values() (vals []string) { func (s *userIDSet) add(str string) {
vals = make([]string, 0, len(s)) s.Lock()
for str := range s { defer s.Unlock()
s.set[str] = struct{}{}
s.precomputed = s.precomputed[:0] // invalidate cache
}
func (s *userIDSet) remove(str string) {
s.Lock()
defer s.Unlock()
delete(s.set, str)
s.precomputed = s.precomputed[:0] // invalidate cache
}
func (s *userIDSet) precompute() {
s.Lock()
defer s.Unlock()
s.precomputed = s.values()
}
func (s *userIDSet) isIn(str string) bool {
s.Lock()
defer s.Unlock()
_, ok := s.set[str]
return ok
}
func (s *userIDSet) values() (vals []string) {
if len(s.precomputed) > 0 {
return s.precomputed // only return if not invalidated
}
vals = make([]string, 0, len(s.set))
for str := range s.set {
vals = append(vals, str) vals = append(vals, str)
} }
return return

View file

@ -60,7 +60,9 @@ func Context(
Headers: nil, Headers: nil,
} }
} }
filter.Rooms = append(filter.Rooms, roomID) if filter.Rooms != nil {
*filter.Rooms = append(*filter.Rooms, roomID)
}
ctx := req.Context() ctx := req.Context()
membershipRes := roomserver.QueryMembershipForUserResponse{} membershipRes := roomserver.QueryMembershipForUserResponse{}

View file

@ -262,12 +262,8 @@ func (r *messagesReq) retrieveEvents() (
clientEvents []gomatrixserverlib.ClientEvent, start, clientEvents []gomatrixserverlib.ClientEvent, start,
end types.TopologyToken, err error, end types.TopologyToken, err error,
) { ) {
eventFilter := r.filter
// Retrieve the events from the local database. // Retrieve the events from the local database.
streamEvents, err := r.db.GetEventsInTopologicalRange( streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
)
if err != nil { if err != nil {
err = fmt.Errorf("GetEventsInRange: %w", err) err = fmt.Errorf("GetEventsInRange: %w", err)
return return

View file

@ -104,8 +104,8 @@ type Database interface {
// DeletePeek deletes all peeks for a given room by a given user // DeletePeek deletes all peeks for a given room by a given user
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event. // EventPositionInTopology returns the depth and stream position of the given event.
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.

View file

@ -47,14 +47,10 @@ const selectBackwardExtremitiesForRoomSQL = "" +
const deleteBackwardExtremitySQL = "" + const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
const deleteBackwardExtremitiesForRoomSQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
type backwardExtremitiesStatements struct { type backwardExtremitiesStatements struct {
insertBackwardExtremityStmt *sql.Stmt insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
} }
func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
@ -72,9 +68,6 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
return nil, err return nil, err
} }
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -113,10 +106,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return return
} }
func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -233,9 +233,10 @@ func (s *currentRoomStateStatements) SelectCurrentState(
excludeEventIDs []string, excludeEventIDs []string,
) ([]*gomatrixserverlib.HeaderedEvent, error) { ) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err := stmt.QueryContext(ctx, roomID, rows, err := stmt.QueryContext(ctx, roomID,
pq.StringArray(stateFilter.Senders), pq.StringArray(senders),
pq.StringArray(stateFilter.NotSenders), pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
stateFilter.ContainsURL, stateFilter.ContainsURL,

View file

@ -16,21 +16,45 @@ package postgres
import ( import (
"strings" "strings"
"github.com/matrix-org/gomatrixserverlib"
) )
// filterConvertWildcardToSQL converts wildcards as defined in // filterConvertWildcardToSQL converts wildcards as defined in
// https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter // https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter
// to SQL wildcards that can be used with LIKE() // to SQL wildcards that can be used with LIKE()
func filterConvertTypeWildcardToSQL(values []string) []string { func filterConvertTypeWildcardToSQL(values *[]string) []string {
if values == nil { if values == nil {
// Return nil instead of []string{} so IS NULL can work correctly when // Return nil instead of []string{} so IS NULL can work correctly when
// the return value is passed into SQL queries // the return value is passed into SQL queries
return nil return nil
} }
ret := make([]string, len(values)) v := *values
for i := range values { ret := make([]string, len(v))
ret[i] = strings.Replace(values[i], "*", "%", -1) for i := range v {
ret[i] = strings.Replace(v[i], "*", "%", -1)
} }
return ret return ret
} }
// TODO: Replace when Dendrite uses Go 1.18
func getSendersRoomEventFilter(filter *gomatrixserverlib.RoomEventFilter) (senders []string, notSenders []string) {
if filter.Senders != nil {
senders = *filter.Senders
}
if filter.NotSenders != nil {
notSenders = *filter.NotSenders
}
return senders, notSenders
}
func getSendersStateFilterFilter(filter *gomatrixserverlib.StateFilter) (senders []string, notSenders []string) {
if filter.Senders != nil {
senders = *filter.Senders
}
if filter.NotSenders != nil {
notSenders = *filter.NotSenders
}
return senders, notSenders
}

View file

@ -56,12 +56,6 @@ const upsertMembershipSQL = "" +
" ON CONFLICT ON CONSTRAINT syncapi_memberships_unique" + " ON CONFLICT ON CONSTRAINT syncapi_memberships_unique" +
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6" " DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
const selectMembershipSQL = "" +
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
" WHERE room_id = $1 AND user_id = $2 AND membership = ANY($3)" +
" ORDER BY stream_pos DESC" +
" LIMIT 1"
const selectMembershipCountSQL = "" + const selectMembershipCountSQL = "" +
"SELECT COUNT(*) FROM (" + "SELECT COUNT(*) FROM (" +
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" + " SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
@ -69,7 +63,6 @@ const selectMembershipCountSQL = "" +
type membershipsStatements struct { type membershipsStatements struct {
upsertMembershipStmt *sql.Stmt upsertMembershipStmt *sql.Stmt
selectMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt
} }
@ -82,9 +75,6 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil { if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil {
return nil, err return nil, err
} }
if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil {
return nil, err
}
if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil { if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil {
return nil, err return nil, err
} }
@ -111,14 +101,6 @@ func (s *membershipsStatements) UpsertMembership(
return err return err
} }
func (s *membershipsStatements) SelectMembership(
ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt)
err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos)
return
}
func (s *membershipsStatements) SelectMembershipCount( func (s *membershipsStatements) SelectMembershipCount(
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition, ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
) (count int, err error) { ) (count int, err error) {

View file

@ -81,6 +81,15 @@ const insertEventSQL = "" +
const selectEventsSQL = "" + const selectEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)"
const selectEventsWithFilterSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" +
" AND ( $2::text[] IS NULL OR sender = ANY($2) )" +
" AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" +
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" +
" AND ( $6::bool IS NULL OR contains_url = $6 )" +
" LIMIT $7"
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
@ -153,6 +162,7 @@ const selectContextAfterEventSQL = "" +
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt selectEventsStmt *sql.Stmt
selectEventsWitFilterStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt
selectRecentEventsForSyncStmt *sql.Stmt selectRecentEventsForSyncStmt *sql.Stmt
@ -174,6 +184,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL}, {&s.insertEventStmt, insertEventSQL},
{&s.selectEventsStmt, selectEventsSQL}, {&s.selectEventsStmt, selectEventsSQL},
{&s.selectEventsWitFilterStmt, selectEventsWithFilterSQL},
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL}, {&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
{&s.selectRecentEventsStmt, selectRecentEventsSQL}, {&s.selectRecentEventsStmt, selectRecentEventsSQL},
{&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL}, {&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL},
@ -204,11 +215,11 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
stateFilter *gomatrixserverlib.StateFilter, roomIDs []string, stateFilter *gomatrixserverlib.StateFilter, roomIDs []string,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) { ) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err := stmt.QueryContext( rows, err := stmt.QueryContext(
ctx, r.Low(), r.High(), pq.StringArray(roomIDs), ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
pq.StringArray(stateFilter.Senders), pq.StringArray(senders),
pq.StringArray(stateFilter.NotSenders), pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
stateFilter.ContainsURL, stateFilter.ContainsURL,
@ -310,7 +321,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
// Parse content as JSON and search for an "url" key // Parse content as JSON and search for an "url" key
containsURL := false containsURL := false
var content map[string]interface{} var content map[string]interface{}
if json.Unmarshal(event.Content(), &content) != nil { if json.Unmarshal(event.Content(), &content) == nil {
// Set containsURL to true if url is present // Set containsURL to true if url is present
_, containsURL = content["url"] _, containsURL = content["url"]
} }
@ -353,10 +364,11 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
} else { } else {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
} }
senders, notSenders := getSendersRoomEventFilter(eventFilter)
rows, err := stmt.QueryContext( rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(), ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders), pq.StringArray(senders),
pq.StringArray(eventFilter.NotSenders), pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit+1, eventFilter.Limit+1,
@ -398,11 +410,12 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
senders, notSenders := getSendersRoomEventFilter(eventFilter)
stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt) stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt)
rows, err := stmt.QueryContext( rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(), ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders), pq.StringArray(senders),
pq.StringArray(eventFilter.NotSenders), pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit, eventFilter.Limit,
@ -427,15 +440,52 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
// selectEvents returns the events for the given event IDs. If an event is // selectEvents returns the events for the given event IDs. If an event is
// missing from the database, it will be omitted. // missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) SelectEvents( func (s *outputRoomEventsStatements) SelectEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) var (
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) stmt *sql.Stmt
rows *sql.Rows
err error
)
if filter == nil {
stmt = sqlutil.TxStmt(txn, s.selectEventsStmt)
rows, err = stmt.QueryContext(ctx, pq.StringArray(eventIDs))
} else {
senders, notSenders := getSendersRoomEventFilter(filter)
stmt = sqlutil.TxStmt(txn, s.selectEventsWitFilterStmt)
rows, err = stmt.QueryContext(ctx,
pq.StringArray(eventIDs),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
filter.ContainsURL,
filter.Limit,
)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
return rowsToStreamEvents(rows) streamEvents, err := rowsToStreamEvents(rows)
if err != nil {
return nil, err
}
if preserveOrder {
eventMap := make(map[string]types.StreamEvent)
for _, ev := range streamEvents {
eventMap[ev.EventID()] = ev
}
var returnEvents []types.StreamEvent
for _, eventID := range eventIDs {
ev, ok := eventMap[eventID]
if ok {
returnEvents = append(returnEvents, ev)
}
}
return returnEvents, nil
}
return streamEvents, nil
} }
func (s *outputRoomEventsStatements) DeleteEventsForRoom( func (s *outputRoomEventsStatements) DeleteEventsForRoom(
@ -462,10 +512,11 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn
func (s *outputRoomEventsStatements) SelectContextBeforeEvent( func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter, ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
) (evts []*gomatrixserverlib.HeaderedEvent, err error) { ) (evts []*gomatrixserverlib.HeaderedEvent, err error) {
senders, notSenders := getSendersRoomEventFilter(filter)
rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext( rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext(
ctx, roomID, id, filter.Limit, ctx, roomID, id, filter.Limit,
pq.StringArray(filter.Senders), pq.StringArray(senders),
pq.StringArray(filter.NotSenders), pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)), pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
) )
@ -494,10 +545,11 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
func (s *outputRoomEventsStatements) SelectContextAfterEvent( func (s *outputRoomEventsStatements) SelectContextAfterEvent(
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter, ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) { ) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) {
senders, notSenders := getSendersRoomEventFilter(filter)
rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext( rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext(
ctx, roomID, id, filter.Limit, ctx, roomID, id, filter.Limit,
pq.StringArray(filter.Senders), pq.StringArray(senders),
pq.StringArray(filter.NotSenders), pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)), pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
) )

View file

@ -73,9 +73,6 @@ const selectMaxPositionInTopologySQL = "" +
"SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" + "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" +
") ORDER BY stream_position DESC LIMIT 1" ") ORDER BY stream_position DESC LIMIT 1"
const deleteTopologyForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
const selectStreamToTopologicalPositionAscSQL = "" + const selectStreamToTopologicalPositionAscSQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;" "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;"
@ -88,7 +85,6 @@ type outputRoomEventsTopologyStatements struct {
selectEventIDsInRangeDESCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt
deleteTopologyForRoomStmt *sql.Stmt
selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt
selectStreamToTopologicalPositionDescStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt
} }
@ -114,9 +110,6 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) {
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
return nil, err return nil, err
} }
if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil {
return nil, err
}
if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil {
return nil, err return nil, err
} }
@ -148,9 +141,9 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
// is requested or not. // is requested or not.
var stmt *sql.Stmt var stmt *sql.Stmt
if chronologicalOrder { if chronologicalOrder {
stmt = s.selectEventIDsInRangeASCStmt stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt)
} else { } else {
stmt = s.selectEventIDsInRangeDESCStmt stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt)
} }
// Query the event IDs. // Query the event IDs.
@ -203,10 +196,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return return
} }
func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre
// Returns an error if there was a problem talking with the database. // Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events. // Does not include any transaction IDs in the returned events.
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs) streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
// Check if we have all of the event's previous events. If an event is // Check if we have all of the event's previous events. If an event is
// missing, add it to the room's backward extremities. // missing, add it to the room's backward extremities.
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs()) prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false)
if err != nil { if err != nil {
return err return err
} }
@ -429,7 +429,8 @@ func (d *Database) updateRoomState(
func (d *Database) GetEventsInTopologicalRange( func (d *Database) GetEventsInTopologicalRange(
ctx context.Context, ctx context.Context,
from, to *types.TopologyToken, from, to *types.TopologyToken,
roomID string, limit int, roomID string,
filter *gomatrixserverlib.RoomEventFilter,
backwardOrdering bool, backwardOrdering bool,
) (events []types.StreamEvent, err error) { ) (events []types.StreamEvent, err error) {
var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
@ -450,14 +451,14 @@ func (d *Database) GetEventsInTopologicalRange(
// Select the event IDs from the defined range. // Select the event IDs from the defined range.
var eIDs []string var eIDs []string
eIDs, err = d.Topology.SelectEventIDsInRange( eIDs, err = d.Topology.SelectEventIDsInRange(
ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, limit, !backwardOrdering, ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
) )
if err != nil { if err != nil {
return return
} }
// Retrieve the events' contents using their IDs. // Retrieve the events' contents using their IDs.
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs) events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, filter, true)
return return
} }
@ -619,7 +620,7 @@ func (d *Database) fetchMissingStateEvents(
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
// Fetch from the events table first so we pick up the stream ID for the // Fetch from the events table first so we pick up the stream ID for the
// event. // event.
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs) events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -687,6 +688,9 @@ func (d *Database) GetStateDeltas(
// user has ever interacted with — joined to, kicked/banned from, left. // user has ever interacted with — joined to, kicked/banned from, left.
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID)
if err != nil { if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err return nil, nil, err
} }
@ -704,17 +708,23 @@ func (d *Database) GetStateDeltas(
// get all the state events ever (i.e. for all available rooms) between these two positions // get all the state events ever (i.e. for all available rooms) between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
if err != nil { if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err return nil, nil, err
} }
state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
if err != nil { if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err return nil, nil, err
} }
// find out which rooms this user is peeking, if any. // find out which rooms this user is peeking, if any.
// We do this before joins so any peeks get overwritten // We do this before joins so any peeks get overwritten
peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
if err != nil { if err != nil && err != sql.ErrNoRows {
return nil, nil, err return nil, nil, err
} }
@ -725,6 +735,9 @@ func (d *Database) GetStateDeltas(
var s []types.StreamEvent var s []types.StreamEvent
s, err = d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter) s, err = d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
if err != nil { if err != nil {
if err == sql.ErrNoRows {
continue
}
return nil, nil, err return nil, nil, err
} }
state[peek.RoomID] = s state[peek.RoomID] = s
@ -752,6 +765,9 @@ func (d *Database) GetStateDeltas(
var s []types.StreamEvent var s []types.StreamEvent
s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter) s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter)
if err != nil { if err != nil {
if err == sql.ErrNoRows {
continue
}
return nil, nil, err return nil, nil, err
} }
state[roomID] = s state[roomID] = s
@ -802,6 +818,9 @@ func (d *Database) GetStateDeltasForFullStateSync(
// user has ever interacted with — joined to, kicked/banned from, left. // user has ever interacted with — joined to, kicked/banned from, left.
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID)
if err != nil { if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err return nil, nil, err
} }
@ -818,7 +837,7 @@ func (d *Database) GetStateDeltasForFullStateSync(
deltas := make(map[string]types.StateDelta) deltas := make(map[string]types.StateDelta)
peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
if err != nil { if err != nil && err != sql.ErrNoRows {
return nil, nil, err return nil, nil, err
} }
@ -827,6 +846,9 @@ func (d *Database) GetStateDeltasForFullStateSync(
if !peek.Deleted { if !peek.Deleted {
s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter) s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
if stateErr != nil { if stateErr != nil {
if stateErr == sql.ErrNoRows {
continue
}
return nil, nil, stateErr return nil, nil, stateErr
} }
deltas[peek.RoomID] = types.StateDelta{ deltas[peek.RoomID] = types.StateDelta{
@ -840,10 +862,16 @@ func (d *Database) GetStateDeltasForFullStateSync(
// Get all the state events ever between these two positions // Get all the state events ever between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
if err != nil { if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err return nil, nil, err
} }
state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
if err != nil { if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err return nil, nil, err
} }
@ -868,6 +896,9 @@ func (d *Database) GetStateDeltasForFullStateSync(
for _, joinedRoomID := range joinedRoomIDs { for _, joinedRoomID := range joinedRoomIDs {
s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter) s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter)
if stateErr != nil { if stateErr != nil {
if stateErr == sql.ErrNoRows {
continue
}
return nil, nil, stateErr return nil, nil, stateErr
} }
deltas[joinedRoomID] = types.StateDelta{ deltas[joinedRoomID] = types.StateDelta{

View file

@ -41,23 +41,23 @@ const insertAccountDataSQL = "" +
" ON CONFLICT (user_id, room_id, type) DO UPDATE" + " ON CONFLICT (user_id, room_id, type) DO UPDATE" +
" SET id = $5" " SET id = $5"
// further parameters are added by prepareWithFilters
const selectAccountDataInRangeSQL = "" + const selectAccountDataInRangeSQL = "" +
"SELECT room_id, type FROM syncapi_account_data_type" + "SELECT room_id, type FROM syncapi_account_data_type" +
" WHERE user_id = $1 AND id > $2 AND id <= $3" + " WHERE user_id = $1 AND id > $2 AND id <= $3"
" ORDER BY id ASC"
const selectMaxAccountDataIDSQL = "" + const selectMaxAccountDataIDSQL = "" +
"SELECT MAX(id) FROM syncapi_account_data_type" "SELECT MAX(id) FROM syncapi_account_data_type"
type accountDataStatements struct { type accountDataStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *StreamIDStatements
insertAccountDataStmt *sql.Stmt insertAccountDataStmt *sql.Stmt
selectMaxAccountDataIDStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt
selectAccountDataInRangeStmt *sql.Stmt selectAccountDataInRangeStmt *sql.Stmt
} }
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) { func NewSqliteAccountDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{ s := &accountDataStatements{
db: db, db: db,
streamIDStatements: streamID, streamIDStatements: streamID,
@ -94,18 +94,24 @@ func (s *accountDataStatements) SelectAccountDataInRange(
ctx context.Context, ctx context.Context,
userID string, userID string,
r types.Range, r types.Range,
accountDataFilterPart *gomatrixserverlib.EventFilter, filter *gomatrixserverlib.EventFilter,
) (data map[string][]string, err error) { ) (data map[string][]string, err error) {
data = make(map[string][]string) data = make(map[string][]string)
stmt, params, err := prepareWithFilters(
s.db, nil, selectAccountDataInRangeSQL,
[]interface{}{
userID, r.Low(), r.High(),
},
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
[]string{}, nil, filter.Limit, FilterOrderAsc)
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High()) rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return return
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
var entries int
for rows.Next() { for rows.Next() {
var dataType string var dataType string
var roomID string var roomID string
@ -114,31 +120,11 @@ func (s *accountDataStatements) SelectAccountDataInRange(
return return
} }
// check if we should add this by looking at the filter.
// It would be nice if we could do this in SQL-land, but the mix of variadic
// and positional parameters makes the query annoyingly hard to do, it's easier
// and clearer to do it in Go-land. If there are no filters for [not]types then
// this gets skipped.
for _, includeType := range accountDataFilterPart.Types {
if includeType != dataType { // TODO: wildcard support
continue
}
}
for _, excludeType := range accountDataFilterPart.NotTypes {
if excludeType == dataType { // TODO: wildcard support
continue
}
}
if len(data[roomID]) > 0 { if len(data[roomID]) > 0 {
data[roomID] = append(data[roomID], dataType) data[roomID] = append(data[roomID], dataType)
} else { } else {
data[roomID] = []string{dataType} data[roomID] = []string{dataType}
} }
entries++
if entries >= accountDataFilterPart.Limit {
break
}
} }
return data, nil return data, nil

View file

@ -47,15 +47,11 @@ const selectBackwardExtremitiesForRoomSQL = "" +
const deleteBackwardExtremitySQL = "" + const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
const deleteBackwardExtremitiesForRoomSQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
type backwardExtremitiesStatements struct { type backwardExtremitiesStatements struct {
db *sql.DB db *sql.DB
insertBackwardExtremityStmt *sql.Stmt insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
} }
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
@ -75,9 +71,6 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
return nil, err return nil, err
} }
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -116,10 +109,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return err return err
} }
func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -90,7 +90,7 @@ const selectEventsWithEventIDsSQL = "" +
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *StreamIDStatements
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
deleteRoomStateForRoomStmt *sql.Stmt deleteRoomStateForRoomStmt *sql.Stmt
@ -100,7 +100,7 @@ type currentRoomStateStatements struct {
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
} }
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{ s := &currentRoomStateStatements{
db: db, db: db,
streamIDStatements: streamID, streamIDStatements: streamID,
@ -220,7 +220,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
}, },
stateFilter.Senders, stateFilter.NotSenders, stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes, stateFilter.Types, stateFilter.NotTypes,
excludeEventIDs, stateFilter.Limit, FilterOrderNone, excludeEventIDs, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderNone,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, fmt.Errorf("s.prepareWithFilters: %w", err)

View file

@ -25,34 +25,53 @@ const (
// parts. // parts.
func prepareWithFilters( func prepareWithFilters(
db *sql.DB, txn *sql.Tx, query string, params []interface{}, db *sql.DB, txn *sql.Tx, query string, params []interface{},
senders, notsenders, types, nottypes []string, excludeEventIDs []string, senders, notsenders, types, nottypes *[]string, excludeEventIDs []string,
limit int, order FilterOrder, containsURL *bool, limit int, order FilterOrder,
) (*sql.Stmt, []interface{}, error) { ) (*sql.Stmt, []interface{}, error) {
offset := len(params) offset := len(params)
if count := len(senders); count > 0 { if senders != nil {
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset) if count := len(*senders); count > 0 {
for _, v := range senders { query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
params, offset = append(params, v), offset+1 for _, v := range *senders {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND sender = ""`
} }
} }
if count := len(notsenders); count > 0 { if notsenders != nil {
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset) if count := len(*notsenders); count > 0 {
for _, v := range notsenders { query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
params, offset = append(params, v), offset+1 for _, v := range *notsenders {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND sender NOT = ""`
} }
} }
if count := len(types); count > 0 { if types != nil {
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset) if count := len(*types); count > 0 {
for _, v := range types { query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
params, offset = append(params, v), offset+1 for _, v := range *types {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND type = ""`
} }
} }
if count := len(nottypes); count > 0 { if nottypes != nil {
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset) if count := len(*nottypes); count > 0 {
for _, v := range nottypes { query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
params, offset = append(params, v), offset+1 for _, v := range *nottypes {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND type NOT = ""`
} }
} }
if containsURL != nil {
query += fmt.Sprintf(" AND contains_url = %v", *containsURL)
}
if count := len(excludeEventIDs); count > 0 { if count := len(excludeEventIDs); count > 0 {
query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset) query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range excludeEventIDs { for _, v := range excludeEventIDs {

View file

@ -59,14 +59,14 @@ const selectMaxInviteIDSQL = "" +
type inviteEventsStatements struct { type inviteEventsStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *StreamIDStatements
insertInviteEventStmt *sql.Stmt insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt
deleteInviteEventStmt *sql.Stmt deleteInviteEventStmt *sql.Stmt
selectMaxInviteIDStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt
} }
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) {
s := &inviteEventsStatements{ s := &inviteEventsStatements{
db: db, db: db,
streamIDStatements: streamID, streamIDStatements: streamID,

View file

@ -18,7 +18,6 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
@ -57,12 +56,6 @@ const upsertMembershipSQL = "" +
" ON CONFLICT (room_id, user_id, membership)" + " ON CONFLICT (room_id, user_id, membership)" +
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6" " DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
const selectMembershipSQL = "" +
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
" WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" +
" ORDER BY stream_pos DESC" +
" LIMIT 1"
const selectMembershipCountSQL = "" + const selectMembershipCountSQL = "" +
"SELECT COUNT(*) FROM (" + "SELECT COUNT(*) FROM (" +
" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" + " SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +
@ -111,22 +104,6 @@ func (s *membershipsStatements) UpsertMembership(
return err return err
} }
func (s *membershipsStatements) SelectMembership(
ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
params := []interface{}{roomID, userID}
for _, membership := range memberships {
params = append(params, membership)
}
orig := strings.Replace(selectMembershipSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
stmt, err := s.db.Prepare(orig)
if err != nil {
return "", 0, 0, err
}
err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos)
return
}
func (s *membershipsStatements) SelectMembershipCount( func (s *membershipsStatements) SelectMembershipCount(
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition, ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
) (count int, err error) { ) (count int, err error) {

View file

@ -58,7 +58,7 @@ const insertEventSQL = "" +
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)" "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
const selectEventsSQL = "" + const selectEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1" "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)"
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
@ -111,9 +111,8 @@ const selectContextAfterEventSQL = "" +
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *StreamIDStatements
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt deleteEventsForRoomStmt *sql.Stmt
@ -122,7 +121,7 @@ type outputRoomEventsStatements struct {
selectContextAfterEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt
} }
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) {
s := &outputRoomEventsStatements{ s := &outputRoomEventsStatements{
db: db, db: db,
streamIDStatements: streamID, streamIDStatements: streamID,
@ -133,7 +132,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
} }
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL}, {&s.insertEventStmt, insertEventSQL},
{&s.selectEventsStmt, selectEventsSQL},
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL}, {&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
{&s.updateEventJSONStmt, updateEventJSONSQL}, {&s.updateEventJSONStmt, updateEventJSONSQL},
{&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL}, {&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL},
@ -170,7 +168,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
s.db, txn, stmtSQL, inputParams, s.db, txn, stmtSQL, inputParams,
stateFilter.Senders, stateFilter.NotSenders, stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes, stateFilter.Types, stateFilter.NotTypes,
nil, stateFilter.Limit, FilterOrderAsc, nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc,
) )
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -279,7 +277,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
// Parse content as JSON and search for an "url" key // Parse content as JSON and search for an "url" key
containsURL := false containsURL := false
var content map[string]interface{} var content map[string]interface{}
if json.Unmarshal(event.Content(), &content) != nil { if json.Unmarshal(event.Content(), &content) == nil {
// Set containsURL to true if url is present // Set containsURL to true if url is present
_, containsURL = content["url"] _, containsURL = content["url"]
} }
@ -347,7 +345,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
}, },
eventFilter.Senders, eventFilter.NotSenders, eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes, eventFilter.Types, eventFilter.NotTypes,
nil, eventFilter.Limit+1, FilterOrderDesc, nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc,
) )
if err != nil { if err != nil {
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -395,7 +393,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
}, },
eventFilter.Senders, eventFilter.NotSenders, eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes, eventFilter.Types, eventFilter.NotTypes,
nil, eventFilter.Limit, FilterOrderAsc, nil, eventFilter.ContainsURL, eventFilter.Limit, FilterOrderAsc,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -421,21 +419,50 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
// selectEvents returns the events for the given event IDs. If an event is // selectEvents returns the events for the given event IDs. If an event is
// missing from the database, it will be omitted. // missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) SelectEvents( func (s *outputRoomEventsStatements) SelectEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
var returnEvents []types.StreamEvent iEventIDs := make([]interface{}, len(eventIDs))
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) for i := range eventIDs {
for _, eventID := range eventIDs { iEventIDs[i] = eventIDs[i]
rows, err := stmt.QueryContext(ctx, eventID)
if err != nil {
return nil, err
}
if streamEvents, err := rowsToStreamEvents(rows); err == nil {
returnEvents = append(returnEvents, streamEvents...)
}
internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
} }
return returnEvents, nil selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
if filter == nil {
filter = &gomatrixserverlib.RoomEventFilter{Limit: 20}
}
stmt, params, err := prepareWithFilters(
s.db, txn, selectSQL, iEventIDs,
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
)
if err != nil {
return nil, err
}
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
streamEvents, err := rowsToStreamEvents(rows)
if err != nil {
return nil, err
}
if preserveOrder {
var returnEvents []types.StreamEvent
eventMap := make(map[string]types.StreamEvent)
for _, ev := range streamEvents {
eventMap[ev.EventID()] = ev
}
for _, eventID := range eventIDs {
ev, ok := eventMap[eventID]
if ok {
returnEvents = append(returnEvents, ev)
}
}
return returnEvents, nil
}
return streamEvents, nil
} }
func (s *outputRoomEventsStatements) DeleteEventsForRoom( func (s *outputRoomEventsStatements) DeleteEventsForRoom(
@ -507,7 +534,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
}, },
filter.Senders, filter.NotSenders, filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes, filter.Types, filter.NotTypes,
nil, filter.Limit, FilterOrderDesc, nil, filter.ContainsURL, filter.Limit, FilterOrderDesc,
) )
rows, err := stmt.QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)
@ -543,7 +570,7 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent(
}, },
filter.Senders, filter.NotSenders, filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes, filter.Types, filter.NotTypes,
nil, filter.Limit, FilterOrderAsc, nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
) )
rows, err := stmt.QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)

View file

@ -78,7 +78,6 @@ type outputRoomEventsTopologyStatements struct {
selectEventIDsInRangeDESCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt
deleteTopologyForRoomStmt *sql.Stmt
selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt
selectStreamToTopologicalPositionDescStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt
} }
@ -191,10 +190,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return return
} }
func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -66,7 +66,7 @@ const selectMaxPeekIDSQL = "" +
type peekStatements struct { type peekStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *StreamIDStatements
insertPeekStmt *sql.Stmt insertPeekStmt *sql.Stmt
deletePeekStmt *sql.Stmt deletePeekStmt *sql.Stmt
deletePeeksStmt *sql.Stmt deletePeeksStmt *sql.Stmt
@ -75,7 +75,7 @@ type peekStatements struct {
selectMaxPeekIDStmt *sql.Stmt selectMaxPeekIDStmt *sql.Stmt
} }
func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) { func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) {
_, err := db.Exec(peeksSchema) _, err := db.Exec(peeksSchema)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -75,7 +75,7 @@ const selectPresenceAfter = "" +
type presenceStatements struct { type presenceStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *StreamIDStatements
upsertPresenceStmt *sql.Stmt upsertPresenceStmt *sql.Stmt
upsertPresenceFromSyncStmt *sql.Stmt upsertPresenceFromSyncStmt *sql.Stmt
selectPresenceForUsersStmt *sql.Stmt selectPresenceForUsersStmt *sql.Stmt
@ -83,7 +83,7 @@ type presenceStatements struct {
selectPresenceAfterStmt *sql.Stmt selectPresenceAfterStmt *sql.Stmt
} }
func NewSqlitePresenceTable(db *sql.DB, streamID *streamIDStatements) (*presenceStatements, error) { func NewSqlitePresenceTable(db *sql.DB, streamID *StreamIDStatements) (*presenceStatements, error) {
_, err := db.Exec(presenceSchema) _, err := db.Exec(presenceSchema)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -59,13 +59,13 @@ const selectMaxReceiptIDSQL = "" +
type receiptStatements struct { type receiptStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *StreamIDStatements
upsertReceipt *sql.Stmt upsertReceipt *sql.Stmt
selectRoomReceipts *sql.Stmt selectRoomReceipts *sql.Stmt
selectMaxReceiptID *sql.Stmt selectMaxReceiptID *sql.Stmt
} }
func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) { func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) {
_, err := db.Exec(receiptsSchema) _, err := db.Exec(receiptsSchema)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -32,12 +32,12 @@ const increaseStreamIDStmt = "" +
"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" + "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" +
" RETURNING stream_id" " RETURNING stream_id"
type streamIDStatements struct { type StreamIDStatements struct {
db *sql.DB db *sql.DB
increaseStreamIDStmt *sql.Stmt increaseStreamIDStmt *sql.Stmt
} }
func (s *streamIDStatements) prepare(db *sql.DB) (err error) { func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) {
s.db = db s.db = db
_, err = db.Exec(streamIDTableSchema) _, err = db.Exec(streamIDTableSchema)
if err != nil { if err != nil {
@ -49,31 +49,31 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
return return
} }
func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos) err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos)
return return
} }
func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { func (s *StreamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos) err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
return return
} }
func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { func (s *StreamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos) err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos)
return return
} }
func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { func (s *StreamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos) err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
return return
} }
func (s *streamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos) err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos)
return return

View file

@ -30,7 +30,7 @@ type SyncServerDatasource struct {
shared.Database shared.Database
db *sql.DB db *sql.DB
writer sqlutil.Writer writer sqlutil.Writer
streamID streamIDStatements streamID StreamIDStatements
} }
// NewDatabase creates a new sync server database // NewDatabase creates a new sync server database
@ -49,7 +49,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
} }
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) { func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
if err = d.streamID.prepare(d.db); err != nil { if err = d.streamID.Prepare(d.db); err != nil {
return err return err
} }
accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID) accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)

View file

@ -3,6 +3,7 @@ package storage_test
import ( import (
"context" "context"
"fmt" "fmt"
"reflect"
"testing" "testing"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -38,7 +39,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
if err != nil { if err != nil {
t.Fatalf("WriteEvent failed: %s", err) t.Fatalf("WriteEvent failed: %s", err)
} }
fmt.Println("Event ID", ev.EventID(), " spos=", pos, "depth=", ev.Depth()) t.Logf("Event ID %s spos=%v depth=%v", ev.EventID(), pos, ev.Depth())
positions = append(positions, pos) positions = append(positions, pos)
} }
return return
@ -46,7 +47,6 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
func TestWriteEvents(t *testing.T) { func TestWriteEvents(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
t.Parallel()
alice := test.NewUser() alice := test.NewUser()
r := test.NewRoom(t, alice) r := test.NewRoom(t, alice)
db, close := MustCreateDatabase(t, dbType) db, close := MustCreateDatabase(t, dbType)
@ -61,84 +61,84 @@ func TestRecentEventsPDU(t *testing.T) {
db, close := MustCreateDatabase(t, dbType) db, close := MustCreateDatabase(t, dbType)
defer close() defer close()
alice := test.NewUser() alice := test.NewUser()
var filter gomatrixserverlib.RoomEventFilter // dummy room to make sure SQL queries are filtering on room ID
filter.Limit = 100 MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
// actual test room
r := test.NewRoom(t, alice) r := test.NewRoom(t, alice)
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"}) r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"})
events := r.Events() events := r.Events()
positions := MustWriteEvents(t, db, events) positions := MustWriteEvents(t, db, events)
// dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
latest, err := db.MaxStreamPositionForPDUs(ctx) latest, err := db.MaxStreamPositionForPDUs(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err) t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
} }
testCases := []struct { testCases := []struct {
Name string Name string
From types.StreamPosition From types.StreamPosition
To types.StreamPosition To types.StreamPosition
WantEvents []*gomatrixserverlib.HeaderedEvent Limit int
WantLimited bool ReverseOrder bool
WantEvents []*gomatrixserverlib.HeaderedEvent
WantLimited bool
}{ }{
// The purpose of this test is to make sure that incremental syncs are including up to the latest events. // The purpose of this test is to make sure that incremental syncs are including up to the latest events.
// It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event. // It's a basic sanity test that sync works. It creates a streaming position that is on the penultimate event.
// It makes sure the response includes the final event. // It makes sure the response includes the final event.
{ {
Name: "IncrementalSync penultimate", Name: "penultimate",
From: positions[len(positions)-2], // pretend we are at the penultimate event From: positions[len(positions)-2], // pretend we are at the penultimate event
To: latest, To: latest,
Limit: 100,
WantEvents: events[len(events)-1:], WantEvents: events[len(events)-1:],
WantLimited: false, WantLimited: false,
}, },
/* // The purpose of this test is to check that limits can be applied and work.
// The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the // This is critical for big rooms hence the test here.
// number of returned events. This is critical for big rooms hence the test here. {
{ Name: "limited",
Name: "IncrementalSync limited", From: 0,
DoSync: func() (*types.Response, error) { To: latest,
from := types.StreamingToken{ // pretend we are 10 events behind Limit: 1,
PDUPosition: positions[len(positions)-11], WantEvents: events[len(events)-1:],
} WantLimited: true,
res := types.NewResponse() },
// limit is set to 5 // The purpose of this test is to check that we can return every event with a high
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) // enough limit
}, {
// want the last 5 events, NOT the last 10. Name: "large limited",
WantTimeline: events[len(events)-5:], From: 0,
}, To: latest,
// The purpose of this test is to check that CompleteSync returns all the current state as well as Limit: 100,
// honouring the `numRecentEventsPerRoom` value WantEvents: events,
{ WantLimited: false,
Name: "CompleteSync limited", },
DoSync: func() (*types.Response, error) { // The purpose of this test is to check that we can return events in reverse order
res := types.NewResponse() {
// limit set to 5 Name: "reverse",
return db.CompleteSync(ctx, res, testUserDeviceA, 5) From: positions[len(positions)-3], // 2 events back
}, To: latest,
// want the last 5 events Limit: 100,
WantTimeline: events[len(events)-5:], ReverseOrder: true,
// want all state for the room WantEvents: test.Reversed(events[len(events)-2:]),
WantState: state, WantLimited: false,
}, },
// The purpose of this test is to check that CompleteSync can return everything with a high enough
// `numRecentEventsPerRoom`.
{
Name: "CompleteSync",
DoSync: func() (*types.Response, error) {
res := types.NewResponse()
return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
},
WantTimeline: events,
// We want no state at all as that field in /sync is the delta between the token (beginning of time)
// and the START of the timeline.
}, */
} }
for _, tc := range testCases { for i := range testCases {
tc := testCases[i]
t.Run(tc.Name, func(st *testing.T) { t.Run(tc.Name, func(st *testing.T) {
var filter gomatrixserverlib.RoomEventFilter
filter.Limit = tc.Limit
gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{ gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
From: tc.From, From: tc.From,
To: tc.To, To: tc.To,
}, &filter, true, true) }, &filter, !tc.ReverseOrder, true)
if err != nil { if err != nil {
st.Fatalf("failed to do sync: %s", err) st.Fatalf("failed to do sync: %s", err)
} }
@ -148,100 +148,49 @@ func TestRecentEventsPDU(t *testing.T) {
if len(gotEvents) != len(tc.WantEvents) { if len(gotEvents) != len(tc.WantEvents) {
st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents)) st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents))
} }
for j := range gotEvents {
if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) {
st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON()))
}
}
}) })
} }
}) })
} }
/*
func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
t.Parallel()
db := MustCreateDatabase(t)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
positions := MustWriteEvents(t, db, events)
latest, err := db.SyncPosition(ctx)
if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err)
}
from := types.StreamingToken{
PDUPosition: positions[len(positions)-2],
}
res := types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
if err != nil {
t.Fatalf("failed to IncrementalSync with latest token")
}
roomRes, ok := res.Rooms.Join[testRoomID]
if !ok {
t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res)
}
// returns the last event "Message 10"
assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:]))
prev := roomRes.Timeline.PrevBatch.String()
if prev == "" {
t.Fatalf("IncrementalSync expected prev_batch token")
}
prevBatchToken, err := types.NewTopologyTokenFromString(prev)
if err != nil {
t.Fatalf("failed to NewTopologyTokenFromString : %s", err)
}
// backpaginate 5 messages starting at the latest position.
// head towards the beginning of time
to := types.TopologyToken{}
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true)
if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err)
}
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1]))
}
// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token.
func TestGetEventsInRangeWithStreamToken(t *testing.T) {
t.Parallel()
db := MustCreateDatabase(t)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
MustWriteEvents(t, db, events)
latest, err := db.SyncPosition(ctx)
if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err)
}
// head towards the beginning of time
to := types.StreamingToken{}
// backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err)
}
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
}
// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token // The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token
func TestGetEventsInRangeWithTopologyToken(t *testing.T) { func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
t.Parallel() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db := MustCreateDatabase(t) db, close := MustCreateDatabase(t, dbType)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) defer close()
MustWriteEvents(t, db, events) alice := test.NewUser()
from, err := db.MaxTopologicalPosition(ctx, testRoomID) r := test.NewRoom(t, alice)
if err != nil { for i := 0; i < 10; i++ {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err) r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})
} }
// head towards the beginning of time events := r.Events()
to := types.TopologyToken{} _ = MustWriteEvents(t, db, events)
// backpaginate 5 messages starting at the latest position. from, err := db.MaxTopologicalPosition(ctx, r.ID)
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true) if err != nil {
if err != nil { t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
t.Fatalf("GetEventsInRange returned an error: %s", err) }
} t.Logf("max topo pos = %+v", from)
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) // head towards the beginning of time
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:])) to := types.TopologyToken{}
// backpaginate 5 messages starting at the latest position.
filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
if err != nil {
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
}
gots := db.StreamEventsToEvents(nil, paginatedEvents)
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
})
} }
/*
// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth. // The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth.
// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent // For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent
// will appear FIRST when going backwards. This test creates a DAG like: // will appear FIRST when going backwards. This test creates a DAG like:
@ -651,12 +600,4 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ
tok.Decrement() tok.Decrement()
return &tok return &tok
} }
func reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[len(in)-i-1]
}
return out
}
*/ */

View file

@ -59,7 +59,7 @@ type Events interface {
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
// SelectEarlyEvents returns the earliest events in the given room. // SelectEarlyEvents returns the earliest events in the given room.
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error)
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
@ -84,8 +84,6 @@ type Topology interface {
SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error)
// SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position. // SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position.
SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error) SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error)
// DeleteTopologyForRoom removes all topological information for a room. This should only be done when removing the room entirely.
DeleteTopologyForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
// SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room. // SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room.
SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error) SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error)
} }
@ -132,8 +130,6 @@ type BackwardsExtremities interface {
SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error) SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error)
// DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed.
DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error)
// DeleteBackwardExtremitiesFoorRoomID removes all backward extremities for a room. This should only be done when removing the room entirely.
DeleteBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
} }
// SendToDevice tracks send-to-device messages which are sent to individual // SendToDevice tracks send-to-device messages which are sent to individual
@ -173,7 +169,6 @@ type Receipts interface {
type Memberships interface { type Memberships interface {
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error)
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
} }

View file

@ -0,0 +1,105 @@
package tables_test
import (
"context"
"database/sql"
"fmt"
"reflect"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
)
func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.Events
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresEventsTable(db)
case test.DBTypeSQLite:
var stream sqlite3.StreamIDStatements
if err = stream.Prepare(db); err != nil {
t.Fatalf("failed to prepare stream stmts: %s", err)
}
tab, err = sqlite3.NewSqliteEventsTable(db, &stream)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, db, close
}
func TestOutputRoomEventsTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newOutputRoomEventsTable(t, dbType)
defer close()
events := room.Events()
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
for _, ev := range events {
_, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false)
if err != nil {
return fmt.Errorf("failed to InsertEvent: %s", err)
}
}
// order = 2,0,3,1
wantEventIDs := []string{
events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(),
}
gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, nil, true)
if err != nil {
return fmt.Errorf("failed to SelectEvents: %s", err)
}
gotEventIDs := make([]string, len(gotEvents))
for i := range gotEvents {
gotEventIDs[i] = gotEvents[i].EventID()
}
if !reflect.DeepEqual(gotEventIDs, wantEventIDs) {
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs)
}
// Test that contains_url is correctly populated
urlEv := room.CreateEvent(t, alice, "m.text", map[string]interface{}{
"body": "test.txt",
"url": "mxc://test.txt",
})
if _, err = tab.InsertEvent(ctx, txn, urlEv, nil, nil, nil, false); err != nil {
return fmt.Errorf("failed to InsertEvent: %s", err)
}
wantEventID := []string{urlEv.EventID()}
t := true
gotEvents, err = tab.SelectEvents(ctx, txn, wantEventID, &gomatrixserverlib.RoomEventFilter{Limit: 1, ContainsURL: &t}, true)
if err != nil {
return fmt.Errorf("failed to SelectEvents: %s", err)
}
gotEventIDs = make([]string, len(gotEvents))
for i := range gotEvents {
gotEventIDs[i] = gotEvents[i].EventID()
}
if !reflect.DeepEqual(gotEventIDs, wantEventID) {
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventID)
}
return nil
})
if err != nil {
t.Fatalf("err: %s", err)
}
})
}

View file

@ -0,0 +1,91 @@
package tables_test
import (
"context"
"database/sql"
"fmt"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
)
func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.Topology
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresTopologyTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSqliteTopologyTable(db)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, db, close
}
func TestTopologyTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newTopologyTable(t, dbType)
defer close()
events := room.Events()
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
var highestPos types.StreamPosition
for i, ev := range events {
topoPos, err := tab.InsertEventInTopology(ctx, txn, ev, types.StreamPosition(i))
if err != nil {
return fmt.Errorf("failed to InsertEventInTopology: %s", err)
}
// topo pos = depth, depth starts at 1, hence 1+i
if topoPos != types.StreamPosition(1+i) {
return fmt.Errorf("got topo pos %d want %d", topoPos, 1+i)
}
highestPos = topoPos + 1
}
// check ordering works without limit
eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, events[:])
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:]))
// check ordering works with limit
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, events[:3])
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:]))
return nil
})
if err != nil {
t.Fatalf("err: %s", err)
}
})
}

View file

@ -3,9 +3,11 @@ package streams
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"sync" "sync"
"time" "time"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -26,7 +28,8 @@ type PDUStreamProvider struct {
tasks chan func() tasks chan func()
workers atomic.Int32 workers atomic.Int32
userAPI userapi.UserInternalAPI // userID+deviceID -> lazy loading cache
lazyLoadCache *caching.LazyLoadCache
} }
func (p *PDUStreamProvider) worker() { func (p *PDUStreamProvider) worker() {
@ -188,7 +191,7 @@ func (p *PDUStreamProvider) IncrementalSync(
newPos = from newPos = from
for _, delta := range stateDeltas { for _, delta := range stateDeltas {
var pos types.StreamPosition var pos types.StreamPosition
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, req.Response); err != nil { if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, &stateFilter, req.Response); err != nil {
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed") req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
return to return to
} }
@ -203,12 +206,14 @@ func (p *PDUStreamProvider) IncrementalSync(
return newPos return newPos
} }
// nolint:gocyclo
func (p *PDUStreamProvider) addRoomDeltaToResponse( func (p *PDUStreamProvider) addRoomDeltaToResponse(
ctx context.Context, ctx context.Context,
device *userapi.Device, device *userapi.Device,
r types.Range, r types.Range,
delta types.StateDelta, delta types.StateDelta,
eventFilter *gomatrixserverlib.RoomEventFilter, eventFilter *gomatrixserverlib.RoomEventFilter,
stateFilter *gomatrixserverlib.StateFilter,
res *types.Response, res *types.Response,
) (types.StreamPosition, error) { ) (types.StreamPosition, error) {
if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave { if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave {
@ -225,13 +230,16 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
eventFilter, true, true, eventFilter, true, true,
) )
if err != nil { if err != nil {
return r.From, err if err == sql.ErrNoRows {
return r.To, nil
}
return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err)
} }
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back
prevBatch, err := p.DB.GetBackwardTopologyPos(ctx, recentStreamEvents) prevBatch, err := p.DB.GetBackwardTopologyPos(ctx, recentStreamEvents)
if err != nil { if err != nil {
return r.From, err return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err)
} }
// If we didn't return any events at all then don't bother doing anything else. // If we didn't return any events at all then don't bother doing anything else.
@ -247,7 +255,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// room that were returned. // room that were returned.
latestPosition := r.To latestPosition := r.To
updateLatestPosition := func(mostRecentEventID string) { updateLatestPosition := func(mostRecentEventID string) {
if _, pos, err := p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil { var pos types.StreamPosition
if _, pos, err = p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil {
switch { switch {
case r.Backwards && pos > latestPosition: case r.Backwards && pos > latestPosition:
fallthrough fallthrough
@ -263,6 +272,16 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
} }
if stateFilter.LazyLoadMembers {
delta.StateEvents, err = p.lazyLoadMembers(
ctx, delta.RoomID, true, limited, stateFilter.IncludeRedundantMembers,
device, recentEvents, delta.StateEvents,
)
if err != nil && err != sql.ErrNoRows {
return r.From, fmt.Errorf("p.lazyLoadMembers: %w", err)
}
}
hasMembershipChange := false hasMembershipChange := false
for _, recentEvent := range recentStreamEvents { for _, recentEvent := range recentStreamEvents {
if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil { if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil {
@ -322,12 +341,16 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
wantFullState bool, wantFullState bool,
device *userapi.Device, device *userapi.Device,
) (jr *types.JoinResponse, err error) { ) (jr *types.JoinResponse, err error) {
jr = types.NewJoinResponse()
// TODO: When filters are added, we may need to call this multiple times to get enough events. // TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
recentStreamEvents, limited, err := p.DB.RecentEvents( recentStreamEvents, limited, err := p.DB.RecentEvents(
ctx, roomID, r, eventFilter, true, true, ctx, roomID, r, eventFilter, true, true,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows {
return jr, nil
}
return return
} }
@ -402,7 +425,20 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
// "Can sync a room with a message with a transaction id" - which does a complete sync to check. // "Can sync a room with a message with a transaction id" - which does a complete sync to check.
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr = types.NewJoinResponse()
if stateFilter.LazyLoadMembers {
if err != nil {
return nil, err
}
stateEvents, err = p.lazyLoadMembers(ctx, roomID,
false, limited, stateFilter.IncludeRedundantMembers,
device, recentEvents, stateEvents,
)
if err != nil && err != sql.ErrNoRows {
return nil, err
}
}
jr.Summary.JoinedMemberCount = &joinedCount jr.Summary.JoinedMemberCount = &joinedCount
jr.Summary.InvitedMemberCount = &invitedCount jr.Summary.InvitedMemberCount = &invitedCount
jr.Timeline.PrevBatch = prevBatch jr.Timeline.PrevBatch = prevBatch
@ -412,6 +448,69 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
return jr, nil return jr, nil
} }
func (p *PDUStreamProvider) lazyLoadMembers(
ctx context.Context, roomID string,
incremental, limited, includeRedundant bool,
device *userapi.Device,
timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
if len(timelineEvents) == 0 {
return stateEvents, nil
}
// Work out which memberships to include
timelineUsers := make(map[string]struct{})
if !incremental {
timelineUsers[device.UserID] = struct{}{}
}
// Add all users the client doesn't know about yet to a list
for _, event := range timelineEvents {
// Membership is not yet cached, add it to the list
if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.Sender()); !ok {
timelineUsers[event.Sender()] = struct{}{}
}
}
// Preallocate with the same amount, even if it will end up with fewer values
newStateEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(stateEvents))
// Remove existing membership events we don't care about, e.g. users not in the timeline.events
for _, event := range stateEvents {
if event.Type() == gomatrixserverlib.MRoomMember && event.StateKey() != nil {
// If this is a gapped incremental sync, we still want this membership
isGappedIncremental := limited && incremental
// We want this users membership event, keep it in the list
_, ok := timelineUsers[event.Sender()]
wantMembership := ok || isGappedIncremental
if wantMembership {
newStateEvents = append(newStateEvents, event)
if !includeRedundant {
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, event.Sender(), event.EventID())
}
delete(timelineUsers, event.Sender())
}
} else {
newStateEvents = append(newStateEvents, event)
}
}
wantUsers := make([]string, 0, len(timelineUsers))
for userID := range timelineUsers {
wantUsers = append(wantUsers, userID)
}
// Query missing membership events
memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &gomatrixserverlib.StateFilter{
Limit: 100,
Senders: &wantUsers,
Types: &[]string{gomatrixserverlib.MRoomMember},
})
if err != nil {
return stateEvents, err
}
// cache the membership events
for _, membership := range memberships {
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, membership.Sender(), membership.EventID())
}
stateEvents = append(newStateEvents, memberships...)
return stateEvents, nil
}
// addIgnoredUsersToFilter adds ignored users to the eventfilter and // addIgnoredUsersToFilter adds ignored users to the eventfilter and
// the syncreq itself for further use in streams. // the syncreq itself for further use in streams.
func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error { func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error {
@ -423,8 +522,12 @@ func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *ty
return err return err
} }
req.IgnoredUsers = *ignores req.IgnoredUsers = *ignores
userList := make([]string, 0, len(ignores.List))
for userID := range ignores.List { for userID := range ignores.List {
eventFilter.NotSenders = append(eventFilter.NotSenders, userID) userList = append(userList, userID)
}
if len(userList) > 0 {
eventFilter.NotSenders = &userList
} }
return nil return nil
} }

View file

@ -27,12 +27,12 @@ type Streams struct {
func NewSyncStreamProviders( func NewSyncStreamProviders(
d storage.Database, userAPI userapi.UserInternalAPI, d storage.Database, userAPI userapi.UserInternalAPI,
rsAPI rsapi.RoomserverInternalAPI, keyAPI keyapi.KeyInternalAPI, rsAPI rsapi.RoomserverInternalAPI, keyAPI keyapi.KeyInternalAPI,
eduCache *caching.EDUCache, notifier *notifier.Notifier, eduCache *caching.EDUCache, lazyLoadCache *caching.LazyLoadCache, notifier *notifier.Notifier,
) *Streams { ) *Streams {
streams := &Streams{ streams := &Streams{
PDUStreamProvider: &PDUStreamProvider{ PDUStreamProvider: &PDUStreamProvider{
StreamProvider: StreamProvider{DB: d}, StreamProvider: StreamProvider{DB: d},
userAPI: userAPI, lazyLoadCache: lazyLoadCache,
}, },
TypingStreamProvider: &TypingStreamProvider{ TypingStreamProvider: &TypingStreamProvider{
StreamProvider: StreamProvider{DB: d}, StreamProvider: StreamProvider{DB: d},

View file

@ -15,6 +15,7 @@
package sync package sync
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -60,10 +61,10 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err) return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
} }
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil { if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil && err != sql.ErrNoRows {
util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed") util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed")
return nil, fmt.Errorf("syncDB.GetFilter: %w", err) return nil, fmt.Errorf("syncDB.GetFilter: %w", err)
} else { } else if f != nil {
filter = *f filter = *f
} }
} }

View file

@ -57,8 +57,12 @@ func AddPublicRoutes(
} }
eduCache := caching.NewTypingCache() eduCache := caching.NewTypingCache()
lazyLoadCache, err := caching.NewLazyLoadCache()
if err != nil {
logrus.WithError(err).Panicf("failed to create lazy loading cache")
}
notifier := notifier.NewNotifier() notifier := notifier.NewNotifier()
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, notifier) streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, lazyLoadCache, notifier)
notifier.SetCurrentPosition(streams.Latest(context.Background())) notifier.SetCurrentPosition(streams.Latest(context.Background()))
if err = notifier.Load(context.Background(), syncDB); err != nil { if err = notifier.Load(context.Background(), syncDB); err != nil {
logrus.WithError(err).Panicf("failed to load notifier ") logrus.WithError(err).Panicf("failed to load notifier ")

View file

@ -312,10 +312,10 @@ Inbound federation can return events
Inbound federation can return missing events for world_readable visibility Inbound federation can return missing events for world_readable visibility
Inbound federation can return missing events for invite visibility Inbound federation can return missing events for invite visibility
Inbound federation can get public room list Inbound federation can get public room list
POST /rooms/:room_id/redact/:event_id as power user redacts message PUT /rooms/:room_id/redact/:event_id/:txn_id as power user redacts message
POST /rooms/:room_id/redact/:event_id as original message sender redacts message PUT /rooms/:room_id/redact/:event_id/:txn_id as original message sender redacts message
POST /rooms/:room_id/redact/:event_id as random user does not redact message PUT /rooms/:room_id/redact/:event_id/:txn_id as random user does not redact message
POST /redact disallows redaction of event in different room PUT /redact disallows redaction of event in different room
An event which redacts itself should be ignored An event which redacts itself should be ignored
A pair of events which redact each other should be ignored A pair of events which redact each other should be ignored
Redaction of a redaction redacts the redaction reason Redaction of a redaction redacts the redaction reason
@ -696,4 +696,17 @@ Room state after a rejected message event is the same as before
Room state after a rejected state event is the same as before Room state after a rejected state event is the same as before
Ignore user in existing room Ignore user in existing room
Ignore invite in full sync Ignore invite in full sync
Ignore invite in incremental sync Ignore invite in incremental sync
A filtered timeline reaches its limit
A change to displayname should not result in a full state sync
Can fetch images in room
The only membership state included in an initial sync is for all the senders in the timeline
The only membership state included in an incremental sync is for senders in the timeline
Old members are included in gappy incr LL sync if they start speaking
We do send redundant membership state across incremental syncs if asked
Rejecting invite over federation doesn't break incremental /sync
Gapped incremental syncs include all state changes
Old leaves are present in gapped incremental syncs
Leaves are present in non-gapped incremental syncs
Members from the gap are included in gappy incr LL sync
Presence can be set from sync

View file

@ -15,12 +15,16 @@
package test package test
import ( import (
"crypto/sha256"
"database/sql" "database/sql"
"encoding/hex"
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
"os/user" "os/user"
"testing" "testing"
"github.com/lib/pq"
) )
type DBType int type DBType int
@ -30,7 +34,7 @@ var DBTypePostgres DBType = 2
var Quiet = false var Quiet = false
func createLocalDB(dbName string) string { func createLocalDB(dbName string) {
if !Quiet { if !Quiet {
fmt.Println("Note: tests require a postgres install accessible to the current user") fmt.Println("Note: tests require a postgres install accessible to the current user")
} }
@ -43,7 +47,29 @@ func createLocalDB(dbName string) string {
if err != nil && !Quiet { if err != nil && !Quiet {
fmt.Println("createLocalDB returned error:", err) fmt.Println("createLocalDB returned error:", err)
} }
return dbName }
func createRemoteDB(t *testing.T, dbName, user, connStr string) {
db, err := sql.Open("postgres", connStr+" dbname=postgres")
if err != nil {
t.Fatalf("failed to open postgres conn with connstr=%s : %s", connStr, err)
}
_, err = db.Exec(fmt.Sprintf(`CREATE DATABASE %s;`, dbName))
if err != nil {
pqErr, ok := err.(*pq.Error)
if !ok {
t.Fatalf("failed to CREATE DATABASE: %s", err)
}
// we ignore duplicate database error as we expect this
if pqErr.Code != "42P04" {
t.Fatalf("failed to CREATE DATABASE with code=%s msg=%s", pqErr.Code, pqErr.Message)
}
}
_, err = db.Exec(fmt.Sprintf(`GRANT ALL PRIVILEGES ON DATABASE %s TO %s`, dbName, user))
if err != nil {
t.Fatalf("failed to GRANT: %s", err)
}
_ = db.Close()
} }
func currentUser() string { func currentUser() string {
@ -64,6 +90,7 @@ func currentUser() string {
// TODO: namespace for concurrent package tests // TODO: namespace for concurrent package tests
func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) { func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) {
if dbType == DBTypeSQLite { if dbType == DBTypeSQLite {
// this will be made in the current working directory which namespaces concurrent package runs correctly
dbname := "dendrite_test.db" dbname := "dendrite_test.db"
return fmt.Sprintf("file:%s", dbname), func() { return fmt.Sprintf("file:%s", dbname), func() {
err := os.Remove(dbname) err := os.Remove(dbname)
@ -79,13 +106,9 @@ func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, clo
if user == "" { if user == "" {
user = currentUser() user = currentUser()
} }
dbName := os.Getenv("POSTGRES_DB")
if dbName == "" {
dbName = createLocalDB("dendrite_test")
}
connStr = fmt.Sprintf( connStr = fmt.Sprintf(
"user=%s dbname=%s sslmode=disable", "user=%s sslmode=disable",
user, dbName, user,
) )
// optional vars, used in CI // optional vars, used in CI
password := os.Getenv("POSTGRES_PASSWORD") password := os.Getenv("POSTGRES_PASSWORD")
@ -97,6 +120,25 @@ func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, clo
connStr += fmt.Sprintf(" host=%s", host) connStr += fmt.Sprintf(" host=%s", host)
} }
// superuser database
postgresDB := os.Getenv("POSTGRES_DB")
// we cannot use 'dendrite_test' here else 2x concurrently running packages will try to use the same db.
// instead, hash the current working directory, snaffle the first 16 bytes and append that to dendrite_test
// and use that as the unique db name. We do this because packages are per-directory hence by hashing the
// working (test) directory we ensure we get a consistent hash and don't hash against concurrent packages.
wd, err := os.Getwd()
if err != nil {
t.Fatalf("cannot get working directory: %s", err)
}
hash := sha256.Sum256([]byte(wd))
dbName := fmt.Sprintf("dendrite_test_%s", hex.EncodeToString(hash[:16]))
if postgresDB == "" { // local server, use createdb
createLocalDB(dbName)
} else { // remote server, shell into the postgres user and CREATE DATABASE
createRemoteDB(t, dbName, user, connStr)
}
connStr += fmt.Sprintf(" dbname=%s", dbName)
return connStr, func() { return connStr, func() {
// Drop all tables on the database to get a fresh instance // Drop all tables on the database to get a fresh instance
db, err := sql.Open("postgres", connStr) db, err := sql.Open("postgres", connStr)
@ -121,6 +163,7 @@ func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) {
for dbName, dbType := range dbs { for dbName, dbType := range dbs {
dbt := dbType dbt := dbType
t.Run(dbName, func(tt *testing.T) { t.Run(dbName, func(tt *testing.T) {
tt.Parallel()
testFn(tt, dbt) testFn(tt, dbt)
}) })
} }

View file

@ -15,7 +15,9 @@
package test package test
import ( import (
"bytes"
"crypto/ed25519" "crypto/ed25519"
"testing"
"time" "time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -49,3 +51,40 @@ func WithUnsigned(unsigned interface{}) eventModifier {
e.unsigned = unsigned e.unsigned = unsigned
} }
} }
// Reverse a list of events
func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[len(in)-i-1]
}
return out
}
func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*gomatrixserverlib.HeaderedEvent) {
t.Helper()
if len(gotEventIDs) != len(wants) {
t.Fatalf("length mismatch: got %d events, want %d", len(gotEventIDs), len(wants))
}
for i := range wants {
w := wants[i].EventID()
g := gotEventIDs[i]
if w != g {
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
}
}
}
func AssertEventsEqual(t *testing.T, gots, wants []*gomatrixserverlib.HeaderedEvent) {
t.Helper()
if len(gots) != len(wants) {
t.Fatalf("length mismatch: got %d events, want %d", len(gots), len(wants))
}
for i := range wants {
w := wants[i].JSON()
g := gots[i].JSON()
if !bytes.Equal(w, g) {
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
}
}
}

View file

@ -492,16 +492,16 @@ type PerformPusherDeletionRequest struct {
// Pusher represents a push notification subscriber // Pusher represents a push notification subscriber
type Pusher struct { type Pusher struct {
SessionID int64 `json:"session_id,omitempty"` SessionID int64 `json:"session_id,omitempty"`
PushKey string `json:"pushkey"` PushKey string `json:"pushkey"`
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"` PushKeyTS int64 `json:"pushkey_ts,omitempty"`
Kind PusherKind `json:"kind"` Kind PusherKind `json:"kind"`
AppID string `json:"app_id"` AppID string `json:"app_id"`
AppDisplayName string `json:"app_display_name"` AppDisplayName string `json:"app_display_name"`
DeviceDisplayName string `json:"device_display_name"` DeviceDisplayName string `json:"device_display_name"`
ProfileTag string `json:"profile_tag"` ProfileTag string `json:"profile_tag"`
Language string `json:"lang"` Language string `json:"lang"`
Data map[string]interface{} `json:"data"` Data map[string]interface{} `json:"data"`
} }
type PusherKind string type PusherKind string

View file

@ -653,7 +653,7 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart) return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
} }
if req.Pusher.PushKeyTS == 0 { if req.Pusher.PushKeyTS == 0 {
req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now()) req.Pusher.PushKeyTS = int64(time.Now().Unix())
} }
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart) return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
} }

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -95,7 +94,7 @@ type pushersStatements struct {
// Returns nil error success. // Returns nil error success.
func (s *pushersStatements) InsertPusher( func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64, ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id) logrus.Debugf("Created pusher %d", session_id)

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -95,7 +94,7 @@ type pushersStatements struct {
// Returns nil error success. // Returns nil error success.
func (s *pushersStatements) InsertPusher( func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64, ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
) error { ) error {
_, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) _, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id) logrus.Debugf("Created pusher %d", session_id)

View file

@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
) )
type AccountDataTable interface { type AccountDataTable interface {
@ -96,7 +95,7 @@ type ThreePIDTable interface {
} }
type PusherTable interface { type PusherTable interface {
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error) SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error