Merge branch 'main' into neilalexander/removelibp2p

This commit is contained in:
Neil Alexander 2022-04-21 11:22:46 +01:00
commit 786b89dcd5
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
68 changed files with 1397 additions and 977 deletions

View file

@ -52,6 +52,7 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeMulticast "github.com/matrix-org/pinecone/multicast"
pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeSessions "github.com/matrix-org/pinecone/sessions"
@ -71,11 +72,9 @@ type DendriteMonolith struct {
PineconeRouter *pineconeRouter.Router
PineconeMulticast *pineconeMulticast.Multicast
PineconeQUIC *pineconeSessions.Sessions
PineconeManager *pineconeConnections.ConnectionManager
StorageDirectory string
CacheDirectory string
staticPeerURI string
staticPeerMutex sync.RWMutex
staticPeerAttempt chan struct{}
listener net.Listener
httpServer *http.Server
processContext *process.ProcessContext
@ -104,15 +103,8 @@ func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) {
}
func (m *DendriteMonolith) SetStaticPeer(uri string) {
m.staticPeerMutex.Lock()
m.staticPeerURI = strings.TrimSpace(uri)
m.staticPeerMutex.Unlock()
m.DisconnectType(int(pineconeRouter.PeerTypeRemote))
if uri != "" {
go func() {
m.staticPeerAttempt <- struct{}{}
}()
}
m.PineconeManager.RemovePeers()
m.PineconeManager.AddPeer(strings.TrimSpace(uri))
}
func (m *DendriteMonolith) DisconnectType(peertype int) {
@ -210,43 +202,6 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e
return loginRes.Device.AccessToken, nil
}
func (m *DendriteMonolith) staticPeerConnect() {
connected := map[string]bool{} // URI -> connected?
attempt := func() {
m.staticPeerMutex.RLock()
uri := m.staticPeerURI
m.staticPeerMutex.RUnlock()
if uri == "" {
return
}
for k := range connected {
delete(connected, k)
}
for _, uri := range strings.Split(uri, ",") {
connected[strings.TrimSpace(uri)] = false
}
for _, info := range m.PineconeRouter.Peers() {
connected[info.URI] = true
}
for k, online := range connected {
if !online {
if err := conn.ConnectToPeer(m.PineconeRouter, k); err != nil {
logrus.WithError(err).Error("Failed to connect to static peer")
}
}
}
}
for {
select {
case <-m.processContext.Context().Done():
case <-m.staticPeerAttempt:
attempt()
case <-time.After(time.Second * 5):
attempt()
}
}
}
// nolint:gocyclo
func (m *DendriteMonolith) Start() {
var err error
@ -284,6 +239,7 @@ func (m *DendriteMonolith) Start() {
m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"})
m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter)
m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter)
prefix := hex.EncodeToString(pk)
cfg := &config.Dendrite{}
@ -392,9 +348,6 @@ func (m *DendriteMonolith) Start() {
m.processContext = base.ProcessContext
m.staticPeerAttempt = make(chan struct{}, 1)
go m.staticPeerConnect()
go func() {
m.logger.Info("Listening on ", cfg.Global.ServerName)
m.logger.Fatal(m.httpServer.Serve(m.PineconeQUIC.Protocol("matrix")))

View file

@ -52,6 +52,7 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client
if turnConfig.SharedSecret != "" {
expiry := time.Now().Add(duration).Unix()
resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID)
mac := hmac.New(sha1.New, []byte(turnConfig.SharedSecret))
_, err := mac.Write([]byte(resp.Username))
@ -60,7 +61,6 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client
return jsonerror.InternalServerError()
}
resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID)
resp.Password = base64.StdEncoding.EncodeToString(mac.Sum(nil))
} else if turnConfig.Username != "" && turnConfig.Password != "" {
resp.Username = turnConfig.Username

View file

@ -25,7 +25,6 @@ import (
"net"
"net/http"
"os"
"strings"
"time"
"github.com/gorilla/mux"
@ -47,6 +46,7 @@ import (
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/gomatrixserverlib"
pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeMulticast "github.com/matrix-org/pinecone/multicast"
pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeSessions "github.com/matrix-org/pinecone/sessions"
@ -90,6 +90,13 @@ func main() {
}
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
pManager := pineconeConnections.NewConnectionManager(pRouter)
pMulticast.Start()
if instancePeer != nil && *instancePeer != "" {
pManager.AddPeer(*instancePeer)
}
go func() {
listener, err := net.Listen("tcp", *instanceListen)
@ -119,36 +126,6 @@ func main() {
}
}()
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
pMulticast.Start()
connectToStaticPeer := func() {
connected := map[string]bool{} // URI -> connected?
for _, uri := range strings.Split(*instancePeer, ",") {
connected[strings.TrimSpace(uri)] = false
}
attempt := func() {
for k := range connected {
connected[k] = false
}
for _, info := range pRouter.Peers() {
connected[info.URI] = true
}
for k, online := range connected {
if !online {
if err := conn.ConnectToPeer(pRouter, k); err != nil {
logrus.WithError(err).Error("Failed to connect to static peer")
}
}
}
}
for {
attempt()
time.Sleep(time.Second * 5)
}
}
cfg := &config.Dendrite{}
cfg.Defaults(true)
cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))
@ -268,7 +245,6 @@ func main() {
Handler: pMux,
}
go connectToStaticPeer()
go func() {
pubkey := pRouter.PublicKey()
logrus.Info("Listening on ", hex.EncodeToString(pubkey[:]))

View file

@ -22,7 +22,6 @@ import (
"encoding/hex"
"fmt"
"syscall/js"
"time"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/appservice"
@ -44,6 +43,7 @@ import (
_ "github.com/matrix-org/go-sqlite3-js"
pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeSessions "github.com/matrix-org/pinecone/sessions"
)
@ -154,6 +154,8 @@ func startup() {
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
pSessions := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
pManager := pineconeConnections.NewConnectionManager(pRouter)
pManager.AddPeer("wss://pinecone.matrix.org/public")
cfg := &config.Dendrite{}
cfg.Defaults(true)
@ -237,20 +239,4 @@ func startup() {
}
s.ListenAndServe("fetch")
}()
// Connect to the static peer
go func() {
for {
if pRouter.PeerCount(pineconeRouter.PeerTypeRemote) == 0 {
if err := conn.ConnectToPeer(pRouter, publicPeer); err != nil {
logrus.WithError(err).Error("Failed to connect to static peer")
}
}
select {
case <-base.ProcessContext.Context().Done():
return
case <-time.After(time.Second * 5):
}
}
}()
}

13
go.mod
View file

@ -1,8 +1,8 @@
module github.com/matrix-org/dendrite
replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e
replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9
replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c
replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e
require (
github.com/Arceliar/ironwood v0.0.0-20211125050254-8951369625d0
@ -27,15 +27,15 @@ require (
github.com/juju/testing v0.0.0-20220203020004-a0ff61f03494 // indirect
github.com/lib/pq v1.10.5
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5
github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d
github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f
github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.10
github.com/miekg/dns v1.1.31 // indirect
github.com/nats-io/nats-server/v2 v2.7.4-0.20220309205833-773636c1c5bb
github.com/nats-io/nats.go v1.13.1-0.20220308171302-2f2f6968e98d
github.com/nats-io/nats.go v1.14.0
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31
@ -56,6 +56,7 @@ require (
golang.org/x/image v0.0.0-20220321031419-a8550c1d254a
golang.org/x/mobile v0.0.0-20220407111146-e579adbbc4a2
golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0

27
go.sum
View file

@ -789,15 +789,15 @@ github.com/masterzen/winrm v0.0.0-20161014151040-7a535cd943fc/go.mod h1:CfZSN7zw
github.com/masterzen/xmlpath v0.0.0-20140218185901-13f4951698ad/go.mod h1:A0zPC53iKKKcXYxr4ROjpQRQ5FgJXtelNdSmHHuq/tY=
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e h1:DP5RC0Z3XdyBEW5dKt8YPeN6vZbm6OzVaGVp7f1BQRM=
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg=
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d h1:mGhPVaTht5NViFN/UpdrIlRApmH2FWcVaKUH5MdBKiY=
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw2QV3YD/fRrzEDPNGgTlJlvXY0EHHnT87wF3OA=
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5 h1:Fkennny7+Z/5pygrhjFMZbz1j++P2hhhLoT7NO3p8DQ=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48=
github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d h1:1+T4eOPRsf6cr0lMPW4oO2k8TTHm4mqIh65kpEID5Rk=
github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f h1:MZrl4TgTnlaOn2Cu9gJCoJ3oyW5mT4/3QIZGgZXzKl4=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48=
github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 h1:W0sjjC6yjskHX4mb0nk3p0fXAlbU5bAFUFeEtlrPASE=
github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
@ -878,8 +878,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296 h1:vU9tpM3apjYlLLeY23zRWJ9Zktr5jp+mloR942LEOpY=
github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a h1:lem6QCvxR0Y28gth9P+wV2K/zYUUAkJ+55U8cpS0p5I=
github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8=
github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
@ -888,10 +888,10 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY
github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM=
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e h1:5tEHLzvDeS6IeqO2o9FFhsE3V2erYj8FlMt2J91wzsk=
github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e/go.mod h1:1vZ2Nijh8tcyNe8BDVyTviCd9NYzRbubQYiEHsvOQWc=
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q=
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9 h1:VGU5HYAwy8LRbSkrT+kCHvujVmwK8Aa/vc1O+eReTbM=
github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9/go.mod h1:5vic7C58BFEVltiZhs7Kq81q2WcEPhJPsmNv1FOrdv0=
github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e h1:kNIzIzj2OvnlreA+sTJ12nWJzTP3OSLNKDL/Iq9mF6Y=
github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 h1:lrVQzBtkeQEGGYUHwSX1XPe1E5GL6U3KYCNe2G4bncQ=
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
@ -1541,8 +1541,9 @@ golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64 h1:D1v9ucDTYBtbz5vNuBbAhIMAGhQhJ6Ym5ah3maMVNX4=
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 h1:QyVthZKMsyaQwBTJE04jdNN0Pp5Fn9Qga0mrgxyERQM=
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=

View file

@ -0,0 +1,86 @@
package caching
import (
"fmt"
"time"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
const (
LazyLoadCacheName = "lazy_load_members"
LazyLoadCacheMaxEntries = 128
LazyLoadCacheMaxUserEntries = 128
LazyLoadCacheMutable = true
LazyLoadCacheMaxAge = time.Minute * 30
)
type LazyLoadCache struct {
// InMemoryLRUCachePartition containing other InMemoryLRUCachePartitions
// with the actual cached members
userCaches *InMemoryLRUCachePartition
}
// NewLazyLoadCache creates a new LazyLoadCache.
func NewLazyLoadCache() (*LazyLoadCache, error) {
cache, err := NewInMemoryLRUCachePartition(
LazyLoadCacheName,
LazyLoadCacheMutable,
LazyLoadCacheMaxEntries,
LazyLoadCacheMaxAge,
true,
)
if err != nil {
return nil, err
}
go cacheCleaner(cache)
return &LazyLoadCache{
userCaches: cache,
}, nil
}
func (c *LazyLoadCache) lazyLoadCacheForUser(device *userapi.Device) (*InMemoryLRUCachePartition, error) {
cacheName := fmt.Sprintf("%s/%s", device.UserID, device.ID)
userCache, ok := c.userCaches.Get(cacheName)
if ok && userCache != nil {
if cache, ok := userCache.(*InMemoryLRUCachePartition); ok {
return cache, nil
}
}
cache, err := NewInMemoryLRUCachePartition(
LazyLoadCacheName,
LazyLoadCacheMutable,
LazyLoadCacheMaxUserEntries,
LazyLoadCacheMaxAge,
false,
)
if err != nil {
return nil, err
}
c.userCaches.Set(cacheName, cache)
go cacheCleaner(cache)
return cache, nil
}
func (c *LazyLoadCache) StoreLazyLoadedUser(device *userapi.Device, roomID, userID, eventID string) {
cache, err := c.lazyLoadCacheForUser(device)
if err != nil {
return
}
cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID)
cache.Set(cacheKey, eventID)
}
func (c *LazyLoadCache) IsLazyLoadedUserCached(device *userapi.Device, roomID, userID string) (string, bool) {
cache, err := c.lazyLoadCacheForUser(device)
if err != nil {
return "", false
}
cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID)
val, ok := cache.Get(cacheKey)
if !ok {
return "", ok
}
return val.(string), ok
}

View file

@ -3,8 +3,6 @@ package pushgateway
import (
"context"
"encoding/json"
"github.com/matrix-org/gomatrixserverlib"
)
// A Client is how interactions with a Push Gateway is done.
@ -50,7 +48,7 @@ type Device struct {
AppID string `json:"app_id"` // Required
Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys.
PushKey string `json:"pushkey"` // Required
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
PushKeyTS int64 `json:"pushkey_ts,omitempty"`
Tweaks map[string]interface{} `json:"tweaks,omitempty"`
}

View file

@ -32,7 +32,7 @@ func AddPublicRoutes(
userAPI userapi.UserInternalAPI,
client *gomatrixserverlib.Client,
) {
mediaDB, err := storage.Open(&cfg.Database)
mediaDB, err := storage.NewMediaAPIDatasource(&cfg.Database)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to media db")
}

View file

@ -22,6 +22,7 @@ import (
"io"
"net/http"
"net/url"
"os"
"path"
"strings"
@ -311,6 +312,26 @@ func (r *uploadRequest) storeFileAndMetadata(
}
go func() {
file, err := os.Open(string(finalPath))
if err != nil {
r.Logger.WithError(err).Error("unable to open file")
return
}
defer file.Close() // nolint: errcheck
// http.DetectContentType only needs 512 bytes
buf := make([]byte, 512)
_, err = file.Read(buf)
if err != nil {
r.Logger.WithError(err).Error("unable to read file")
return
}
// Check if we need to generate thumbnails
fileType := http.DetectContentType(buf)
if !strings.HasPrefix(fileType, "image") {
r.Logger.WithField("contentType", fileType).Debugf("uploaded file is not an image or can not be thumbnailed, not generating thumbnails")
return
}
busy, err := thumbnailer.GenerateThumbnails(
context.Background(), finalPath, thumbnailSizes, r.MediaMetadata,
activeThumbnailGeneration, maxThumbnailGenerators, db, r.Logger,

View file

@ -51,7 +51,7 @@ func Test_uploadRequest_doUpload(t *testing.T) {
_ = os.Mkdir(testdataPath, os.ModePerm)
defer fileutils.RemoveDir(types.Path(testdataPath), nil)
db, err := storage.Open(&config.DatabaseOptions{
db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{
ConnectionString: "file::memory:?cache=shared",
MaxOpenConnections: 100,
MaxIdleConnections: 2,

View file

@ -22,9 +22,17 @@ import (
)
type Database interface {
MediaRepository
Thumbnails
}
type MediaRepository interface {
StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error
GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
}
type Thumbnails interface {
StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error
GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error)
GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error)

View file

@ -20,6 +20,8 @@ import (
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -69,24 +71,25 @@ type mediaStatements struct {
selectMediaByHashStmt *sql.Stmt
}
func (s *mediaStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(mediaSchema)
func NewPostgresMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) {
s := &mediaStatements{}
_, err := db.Exec(mediaSchema)
if err != nil {
return
return nil, err
}
return statementList{
return s, sqlutil.StatementList{
{&s.insertMediaStmt, insertMediaSQL},
{&s.selectMediaStmt, selectMediaSQL},
{&s.selectMediaByHashStmt, selectMediaByHashSQL},
}.prepare(db)
}.Prepare(db)
}
func (s *mediaStatements) insertMedia(
ctx context.Context, mediaMetadata *types.MediaMetadata,
func (s *mediaStatements) InsertMedia(
ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata,
) error {
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
_, err := s.insertMediaStmt.ExecContext(
mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext(
ctx,
mediaMetadata.MediaID,
mediaMetadata.Origin,
@ -100,14 +103,14 @@ func (s *mediaStatements) insertMedia(
return err
}
func (s *mediaStatements) selectMedia(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
func (s *mediaStatements) SelectMedia(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
MediaID: mediaID,
Origin: mediaOrigin,
}
err := s.selectMediaStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext(
ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,
@ -120,14 +123,14 @@ func (s *mediaStatements) selectMedia(
return &mediaMetadata, err
}
func (s *mediaStatements) selectMediaByHash(
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
func (s *mediaStatements) SelectMediaByHash(
ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
Base64Hash: mediaHash,
Origin: mediaOrigin,
}
err := s.selectMediaStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext(
ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,

View file

@ -0,0 +1,46 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/shared"
"github.com/matrix-org/dendrite/setup/config"
)
// NewDatabase opens a postgres database.
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
mediaRepo, err := NewPostgresMediaRepositoryTable(db)
if err != nil {
return nil, err
}
thumbnails, err := NewPostgresThumbnailsTable(db)
if err != nil {
return nil, err
}
return &shared.Database{
MediaRepository: mediaRepo,
Thumbnails: thumbnails,
DB: db,
Writer: sqlutil.NewExclusiveWriter(),
}, nil
}

View file

@ -1,38 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// FIXME: This should be made internal!
package postgres
import (
"database/sql"
)
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
type statementList []struct {
statement **sql.Stmt
sql string
}
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
func (s statementList) prepare(db *sql.DB) (err error) {
for _, statement := range s {
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
return
}
}
return
}

View file

@ -1,36 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"database/sql"
)
type statements struct {
media mediaStatements
thumbnail thumbnailStatements
}
func (s *statements) prepare(db *sql.DB) (err error) {
if err = s.media.prepare(db); err != nil {
return
}
if err = s.thumbnail.prepare(db); err != nil {
return
}
return
}

View file

@ -21,6 +21,8 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -63,7 +65,7 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE
// Note: this selects all thumbnails for a media_origin and media_id
const selectThumbnailsSQL = `
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC
`
type thumbnailStatements struct {
@ -72,24 +74,25 @@ type thumbnailStatements struct {
selectThumbnailsStmt *sql.Stmt
}
func (s *thumbnailStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(thumbnailSchema)
func NewPostgresThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) {
s := &thumbnailStatements{}
_, err := db.Exec(thumbnailSchema)
if err != nil {
return
return nil, err
}
return statementList{
return s, sqlutil.StatementList{
{&s.insertThumbnailStmt, insertThumbnailSQL},
{&s.selectThumbnailStmt, selectThumbnailSQL},
{&s.selectThumbnailsStmt, selectThumbnailsSQL},
}.prepare(db)
}.Prepare(db)
}
func (s *thumbnailStatements) insertThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
func (s *thumbnailStatements) InsertThumbnail(
ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata,
) error {
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
_, err := s.insertThumbnailStmt.ExecContext(
thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext(
ctx,
thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin,
@ -103,8 +106,9 @@ func (s *thumbnailStatements) insertThumbnail(
return err
}
func (s *thumbnailStatements) selectThumbnail(
func (s *thumbnailStatements) SelectThumbnail(
ctx context.Context,
txn *sql.Tx,
mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
width, height int,
@ -121,7 +125,7 @@ func (s *thumbnailStatements) selectThumbnail(
ResizeMethod: resizeMethod,
},
}
err := s.selectThumbnailStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext(
ctx,
thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin,
@ -136,10 +140,10 @@ func (s *thumbnailStatements) selectThumbnail(
return &thumbnailMetadata, err
}
func (s *thumbnailStatements) selectThumbnails(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
func (s *thumbnailStatements) SelectThumbnails(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error) {
rows, err := s.selectThumbnailsStmt.QueryContext(
rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext(
ctx, mediaID, mediaOrigin,
)
if err != nil {

View file

@ -1,5 +1,4 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,54 +12,38 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
package shared
import (
"context"
"database/sql"
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database is used to store metadata about a repository of media files.
type Database struct {
statements statements
db *sql.DB
}
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
var d Database
var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
if err = d.statements.prepare(d.db); err != nil {
return nil, err
}
return &d, nil
DB *sql.DB
Writer sqlutil.Writer
MediaRepository tables.MediaRepository
Thumbnails tables.Thumbnails
}
// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreMediaMetadata(
ctx context.Context, mediaMetadata *types.MediaMetadata,
) error {
return d.statements.media.insertMedia(ctx, mediaMetadata)
func (d Database) StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.MediaRepository.InsertMedia(ctx, txn, mediaMetadata)
})
}
// GetMediaMetadata returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadata(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin)
func (d Database) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
mediaMetadata, err := d.MediaRepository.SelectMedia(ctx, nil, mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
@ -70,10 +53,8 @@ func (d *Database) GetMediaMetadata(
// GetMediaMetadataByHash returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadataByHash(
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin)
func (d Database) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
mediaMetadata, err := d.MediaRepository.SelectMediaByHash(ctx, nil, mediaHash, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
@ -82,40 +63,36 @@ func (d *Database) GetMediaMetadataByHash(
// StoreThumbnail inserts the metadata about the thumbnail into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
) error {
return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata)
func (d Database) StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Thumbnails.InsertThumbnail(ctx, txn, thumbnailMetadata)
})
}
// GetThumbnail returns metadata about a specific thumbnail.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this thumbnail.
func (d *Database) GetThumbnail(
ctx context.Context,
mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
width, height int,
resizeMethod string,
) (*types.ThumbnailMetadata, error) {
thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail(
ctx, mediaID, mediaOrigin, width, height, resizeMethod,
)
if err != nil && err == sql.ErrNoRows {
func (d Database) GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) {
metadata, err := d.Thumbnails.SelectThumbnail(ctx, nil, mediaID, mediaOrigin, width, height, resizeMethod)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return thumbnailMetadata, err
return nil, err
}
return metadata, err
}
// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there are no thumbnails associated with this media.
func (d *Database) GetThumbnails(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error) {
thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
func (d Database) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) {
metadatas, err := d.Thumbnails.SelectThumbnails(ctx, nil, mediaID, mediaOrigin)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return thumbnails, err
return nil, err
}
return metadatas, err
}

View file

@ -21,6 +21,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -66,35 +67,32 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, media_id, user_i
type mediaStatements struct {
db *sql.DB
writer sqlutil.Writer
insertMediaStmt *sql.Stmt
selectMediaStmt *sql.Stmt
selectMediaByHashStmt *sql.Stmt
}
func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
s.writer = writer
_, err = db.Exec(mediaSchema)
func NewSQLiteMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) {
s := &mediaStatements{
db: db,
}
_, err := db.Exec(mediaSchema)
if err != nil {
return
return nil, err
}
return statementList{
return s, sqlutil.StatementList{
{&s.insertMediaStmt, insertMediaSQL},
{&s.selectMediaStmt, selectMediaSQL},
{&s.selectMediaByHashStmt, selectMediaByHashSQL},
}.prepare(db)
}.Prepare(db)
}
func (s *mediaStatements) insertMedia(
ctx context.Context, mediaMetadata *types.MediaMetadata,
func (s *mediaStatements) InsertMedia(
ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata,
) error {
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertMediaStmt)
_, err := stmt.ExecContext(
mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext(
ctx,
mediaMetadata.MediaID,
mediaMetadata.Origin,
@ -106,17 +104,16 @@ func (s *mediaStatements) insertMedia(
mediaMetadata.UserID,
)
return err
})
}
func (s *mediaStatements) selectMedia(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
func (s *mediaStatements) SelectMedia(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
MediaID: mediaID,
Origin: mediaOrigin,
}
err := s.selectMediaStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext(
ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,
@ -129,14 +126,14 @@ func (s *mediaStatements) selectMedia(
return &mediaMetadata, err
}
func (s *mediaStatements) selectMediaByHash(
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
func (s *mediaStatements) SelectMediaByHash(
ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
Base64Hash: mediaHash,
Origin: mediaOrigin,
}
err := s.selectMediaStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext(
ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,

View file

@ -16,23 +16,30 @@
package sqlite3
import (
"database/sql"
// Import the postgres database driver.
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/shared"
"github.com/matrix-org/dendrite/setup/config"
)
type statements struct {
media mediaStatements
thumbnail thumbnailStatements
}
func (s *statements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
if err = s.media.prepare(db, writer); err != nil {
return
// NewDatabase opens a SQLIte database.
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
if err = s.thumbnail.prepare(db, writer); err != nil {
return
mediaRepo, err := NewSQLiteMediaRepositoryTable(db)
if err != nil {
return nil, err
}
return
thumbnails, err := NewSQLiteThumbnailsTable(db)
if err != nil {
return nil, err
}
return &shared.Database{
MediaRepository: mediaRepo,
Thumbnails: thumbnails,
DB: db,
Writer: sqlutil.NewExclusiveWriter(),
}, nil
}

View file

@ -1,38 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// FIXME: This should be made internal!
package sqlite3
import (
"database/sql"
)
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
type statementList []struct {
statement **sql.Stmt
sql string
}
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
func (s statementList) prepare(db *sql.DB) (err error) {
for _, statement := range s {
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
return
}
}
return
}

View file

@ -1,123 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
// Import the postgres database driver.
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database is used to store metadata about a repository of media files.
type Database struct {
statements statements
db *sql.DB
writer sqlutil.Writer
}
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
d := Database{
writer: sqlutil.NewExclusiveWriter(),
}
var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
if err = d.statements.prepare(d.db, d.writer); err != nil {
return nil, err
}
return &d, nil
}
// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreMediaMetadata(
ctx context.Context, mediaMetadata *types.MediaMetadata,
) error {
return d.statements.media.insertMedia(ctx, mediaMetadata)
}
// GetMediaMetadata returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadata(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return mediaMetadata, err
}
// GetMediaMetadataByHash returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadataByHash(
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return mediaMetadata, err
}
// StoreThumbnail inserts the metadata about the thumbnail into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
) error {
return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata)
}
// GetThumbnail returns metadata about a specific thumbnail.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this thumbnail.
func (d *Database) GetThumbnail(
ctx context.Context,
mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
width, height int,
resizeMethod string,
) (*types.ThumbnailMetadata, error) {
thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail(
ctx, mediaID, mediaOrigin, width, height, resizeMethod,
)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return thumbnailMetadata, err
}
// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there are no thumbnails associated with this media.
func (d *Database) GetThumbnails(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error) {
thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return thumbnails, err
}

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -54,39 +55,32 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE
// Note: this selects all thumbnails for a media_origin and media_id
const selectThumbnailsSQL = `
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC
`
type thumbnailStatements struct {
db *sql.DB
writer sqlutil.Writer
insertThumbnailStmt *sql.Stmt
selectThumbnailStmt *sql.Stmt
selectThumbnailsStmt *sql.Stmt
}
func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
_, err = db.Exec(thumbnailSchema)
func NewSQLiteThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) {
s := &thumbnailStatements{}
_, err := db.Exec(thumbnailSchema)
if err != nil {
return
return nil, err
}
s.db = db
s.writer = writer
return statementList{
return s, sqlutil.StatementList{
{&s.insertThumbnailStmt, insertThumbnailSQL},
{&s.selectThumbnailStmt, selectThumbnailSQL},
{&s.selectThumbnailsStmt, selectThumbnailsSQL},
}.prepare(db)
}.Prepare(db)
}
func (s *thumbnailStatements) insertThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
) error {
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt)
_, err := stmt.ExecContext(
func (s *thumbnailStatements) InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error {
thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext(
ctx,
thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin,
@ -98,11 +92,11 @@ func (s *thumbnailStatements) insertThumbnail(
thumbnailMetadata.ThumbnailSize.ResizeMethod,
)
return err
})
}
func (s *thumbnailStatements) selectThumbnail(
func (s *thumbnailStatements) SelectThumbnail(
ctx context.Context,
txn *sql.Tx,
mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
width, height int,
@ -119,7 +113,7 @@ func (s *thumbnailStatements) selectThumbnail(
ResizeMethod: resizeMethod,
},
}
err := s.selectThumbnailStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext(
ctx,
thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin,
@ -134,10 +128,11 @@ func (s *thumbnailStatements) selectThumbnail(
return &thumbnailMetadata, err
}
func (s *thumbnailStatements) selectThumbnails(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
func (s *thumbnailStatements) SelectThumbnails(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error) {
rows, err := s.selectThumbnailsStmt.QueryContext(
rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext(
ctx, mediaID, mediaOrigin,
)
if err != nil {

View file

@ -25,13 +25,13 @@ import (
"github.com/matrix-org/dendrite/setup/config"
)
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (Database, error) {
// NewMediaAPIDatasource opens a database connection.
func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.Open(dbProperties)
return sqlite3.NewDatabase(dbProperties)
case dbProperties.ConnectionString.IsPostgres():
return postgres.Open(dbProperties)
return postgres.NewDatabase(dbProperties)
default:
return nil, fmt.Errorf("unexpected database type")
}

View file

@ -0,0 +1,135 @@
package storage_test
import (
"context"
"reflect"
"testing"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("NewSyncServerDatasource returned %s", err)
}
return db, close
}
func TestMediaRepository(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
ctx := context.Background()
t.Run("can insert media & query media", func(t *testing.T) {
metadata := &types.MediaMetadata{
MediaID: "testing",
Origin: "localhost",
ContentType: "image/png",
FileSizeBytes: 10,
UploadName: "upload test",
Base64Hash: "dGVzdGluZw==",
UserID: "@alice:localhost",
}
if err := db.StoreMediaMetadata(ctx, metadata); err != nil {
t.Fatalf("unable to store media metadata: %v", err)
}
// query by media id
gotMetadata, err := db.GetMediaMetadata(ctx, metadata.MediaID, metadata.Origin)
if err != nil {
t.Fatalf("unable to query media metadata: %v", err)
}
if !reflect.DeepEqual(metadata, gotMetadata) {
t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata)
}
// query by media hash
gotMetadata, err = db.GetMediaMetadataByHash(ctx, metadata.Base64Hash, metadata.Origin)
if err != nil {
t.Fatalf("unable to query media metadata by hash: %v", err)
}
if !reflect.DeepEqual(metadata, gotMetadata) {
t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata)
}
})
})
}
func TestThumbnailsStorage(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
ctx := context.Background()
t.Run("can insert thumbnails & query media", func(t *testing.T) {
thumbnails := []*types.ThumbnailMetadata{
{
MediaMetadata: &types.MediaMetadata{
MediaID: "testing",
Origin: "localhost",
ContentType: "image/png",
FileSizeBytes: 6,
},
ThumbnailSize: types.ThumbnailSize{
Width: 5,
Height: 5,
ResizeMethod: types.Crop,
},
},
{
MediaMetadata: &types.MediaMetadata{
MediaID: "testing",
Origin: "localhost",
ContentType: "image/png",
FileSizeBytes: 7,
},
ThumbnailSize: types.ThumbnailSize{
Width: 1,
Height: 1,
ResizeMethod: types.Scale,
},
},
}
for i := range thumbnails {
if err := db.StoreThumbnail(ctx, thumbnails[i]); err != nil {
t.Fatalf("unable to store thumbnail metadata: %v", err)
}
}
// query by single thumbnail
gotMetadata, err := db.GetThumbnail(ctx,
thumbnails[0].MediaMetadata.MediaID,
thumbnails[0].MediaMetadata.Origin,
thumbnails[0].ThumbnailSize.Width, thumbnails[0].ThumbnailSize.Height,
thumbnails[0].ThumbnailSize.ResizeMethod,
)
if err != nil {
t.Fatalf("unable to query thumbnail metadata: %v", err)
}
if !reflect.DeepEqual(thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata) {
t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata)
}
if !reflect.DeepEqual(thumbnails[0].ThumbnailSize, gotMetadata.ThumbnailSize) {
t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata)
}
// query by all thumbnails
gotMediadatas, err := db.GetThumbnails(ctx, thumbnails[0].MediaMetadata.MediaID, thumbnails[0].MediaMetadata.Origin)
if err != nil {
t.Fatalf("unable to query media metadata by hash: %v", err)
}
if len(gotMediadatas) != len(thumbnails) {
t.Fatalf("expected %d stored thumbnail metadata, got %d", len(thumbnails), len(gotMediadatas))
}
for i := range gotMediadatas {
if !reflect.DeepEqual(thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata) {
t.Fatalf("expected metadata %+v, got %v", thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata)
}
if !reflect.DeepEqual(thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize) {
t.Fatalf("expected metadata %+v, got %v", thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize)
}
}
})
})
}

View file

@ -22,10 +22,10 @@ import (
)
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (Database, error) {
func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.Open(dbProperties)
return sqlite3.NewDatabase(dbProperties)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:

View file

@ -0,0 +1,46 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
type Thumbnails interface {
InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error
SelectThumbnail(
ctx context.Context, txn *sql.Tx,
mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
width, height int,
resizeMethod string,
) (*types.ThumbnailMetadata, error)
SelectThumbnails(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error)
}
type MediaRepository interface {
InsertMedia(ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata) error
SelectMedia(ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
SelectMediaByHash(
ctx context.Context, txn *sql.Tx,
mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error)
}

View file

@ -45,16 +45,13 @@ type RequestMethod string
// MatrixUserID is a Matrix user ID string in the form @user:domain e.g. @alice:matrix.org
type MatrixUserID string
// UnixMs is the milliseconds since the Unix epoch
type UnixMs int64
// MediaMetadata is metadata associated with a media file
type MediaMetadata struct {
MediaID MediaID
Origin gomatrixserverlib.ServerName
ContentType ContentType
FileSizeBytes FileSizeBytes
CreationTimestamp UnixMs
CreationTimestamp gomatrixserverlib.Timestamp
UploadName Filename
Base64Hash Base64Hash
UserID MatrixUserID

View file

@ -36,7 +36,7 @@ import (
type Notifier struct {
lock *sync.RWMutex
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
roomIDToJoinedUsers map[string]userIDSet
roomIDToJoinedUsers map[string]*userIDSet
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
roomIDToPeekingDevices map[string]peekingDeviceSet
// The latest sync position
@ -54,7 +54,7 @@ type Notifier struct {
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
func NewNotifier() *Notifier {
return &Notifier{
roomIDToJoinedUsers: make(map[string]userIDSet),
roomIDToJoinedUsers: make(map[string]*userIDSet),
roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
lock: &sync.RWMutex{},
@ -262,7 +262,7 @@ func (n *Notifier) SharedUsers(userID string) []string {
func (n *Notifier) _sharedUsers(userID string) []string {
n._sharedUserMap[userID] = struct{}{}
for roomID, users := range n.roomIDToJoinedUsers {
if _, ok := users[userID]; !ok {
if ok := users.isIn(userID); !ok {
continue
}
for _, userID := range n._joinedUsers(roomID) {
@ -282,8 +282,11 @@ func (n *Notifier) IsSharedUser(userA, userB string) bool {
defer n.lock.RUnlock()
var okA, okB bool
for _, users := range n.roomIDToJoinedUsers {
_, okA = users[userA]
_, okB = users[userB]
okA = users.isIn(userA)
if !okA {
continue
}
okB = users.isIn(userB)
if okA && okB {
return true
}
@ -345,11 +348,12 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
// This is just the bulk form of addJoinedUser
for roomID, userIDs := range roomIDToUserIDs {
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
n.roomIDToJoinedUsers[roomID] = make(userIDSet, len(userIDs))
n.roomIDToJoinedUsers[roomID] = newUserIDSet(len(userIDs))
}
for _, userID := range userIDs {
n.roomIDToJoinedUsers[roomID].add(userID)
}
n.roomIDToJoinedUsers[roomID].precompute()
}
}
@ -440,16 +444,18 @@ func (n *Notifier) _fetchUserStreams(userID string) []*UserDeviceStream {
func (n *Notifier) _addJoinedUser(roomID, userID string) {
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
n.roomIDToJoinedUsers[roomID] = make(userIDSet)
n.roomIDToJoinedUsers[roomID] = newUserIDSet(8)
}
n.roomIDToJoinedUsers[roomID].add(userID)
n.roomIDToJoinedUsers[roomID].precompute()
}
func (n *Notifier) _removeJoinedUser(roomID, userID string) {
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
n.roomIDToJoinedUsers[roomID] = make(userIDSet)
n.roomIDToJoinedUsers[roomID] = newUserIDSet(8)
}
n.roomIDToJoinedUsers[roomID].remove(userID)
n.roomIDToJoinedUsers[roomID].precompute()
}
func (n *Notifier) JoinedUsers(roomID string) (userIDs []string) {
@ -521,19 +527,52 @@ func (n *Notifier) _removeEmptyUserStreams() {
}
// A string set, mainly existing for improving clarity of structs in this file.
type userIDSet map[string]struct{}
func (s userIDSet) add(str string) {
s[str] = struct{}{}
type userIDSet struct {
sync.Mutex
set map[string]struct{}
precomputed []string
}
func (s userIDSet) remove(str string) {
delete(s, str)
func newUserIDSet(cap int) *userIDSet {
return &userIDSet{
set: make(map[string]struct{}, cap),
precomputed: nil,
}
}
func (s userIDSet) values() (vals []string) {
vals = make([]string, 0, len(s))
for str := range s {
func (s *userIDSet) add(str string) {
s.Lock()
defer s.Unlock()
s.set[str] = struct{}{}
s.precomputed = s.precomputed[:0] // invalidate cache
}
func (s *userIDSet) remove(str string) {
s.Lock()
defer s.Unlock()
delete(s.set, str)
s.precomputed = s.precomputed[:0] // invalidate cache
}
func (s *userIDSet) precompute() {
s.Lock()
defer s.Unlock()
s.precomputed = s.values()
}
func (s *userIDSet) isIn(str string) bool {
s.Lock()
defer s.Unlock()
_, ok := s.set[str]
return ok
}
func (s *userIDSet) values() (vals []string) {
if len(s.precomputed) > 0 {
return s.precomputed // only return if not invalidated
}
vals = make([]string, 0, len(s.set))
for str := range s.set {
vals = append(vals, str)
}
return

View file

@ -60,7 +60,9 @@ func Context(
Headers: nil,
}
}
filter.Rooms = append(filter.Rooms, roomID)
if filter.Rooms != nil {
*filter.Rooms = append(*filter.Rooms, roomID)
}
ctx := req.Context()
membershipRes := roomserver.QueryMembershipForUserResponse{}

View file

@ -262,12 +262,8 @@ func (r *messagesReq) retrieveEvents() (
clientEvents []gomatrixserverlib.ClientEvent, start,
end types.TopologyToken, err error,
) {
eventFilter := r.filter
// Retrieve the events from the local database.
streamEvents, err := r.db.GetEventsInTopologicalRange(
r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
)
streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
if err != nil {
err = fmt.Errorf("GetEventsInRange: %w", err)
return

View file

@ -104,8 +104,8 @@ type Database interface {
// DeletePeek deletes all peeks for a given room by a given user
// Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event.
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.

View file

@ -47,14 +47,10 @@ const selectBackwardExtremitiesForRoomSQL = "" +
const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
const deleteBackwardExtremitiesForRoomSQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
type backwardExtremitiesStatements struct {
insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
}
func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
@ -72,9 +68,6 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
return nil, err
}
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
return nil, err
}
return s, nil
}
@ -113,10 +106,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return
}
func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -233,9 +233,10 @@ func (s *currentRoomStateStatements) SelectCurrentState(
excludeEventIDs []string,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err := stmt.QueryContext(ctx, roomID,
pq.StringArray(stateFilter.Senders),
pq.StringArray(stateFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
stateFilter.ContainsURL,

View file

@ -16,21 +16,45 @@ package postgres
import (
"strings"
"github.com/matrix-org/gomatrixserverlib"
)
// filterConvertWildcardToSQL converts wildcards as defined in
// https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter
// to SQL wildcards that can be used with LIKE()
func filterConvertTypeWildcardToSQL(values []string) []string {
func filterConvertTypeWildcardToSQL(values *[]string) []string {
if values == nil {
// Return nil instead of []string{} so IS NULL can work correctly when
// the return value is passed into SQL queries
return nil
}
ret := make([]string, len(values))
for i := range values {
ret[i] = strings.Replace(values[i], "*", "%", -1)
v := *values
ret := make([]string, len(v))
for i := range v {
ret[i] = strings.Replace(v[i], "*", "%", -1)
}
return ret
}
// TODO: Replace when Dendrite uses Go 1.18
func getSendersRoomEventFilter(filter *gomatrixserverlib.RoomEventFilter) (senders []string, notSenders []string) {
if filter.Senders != nil {
senders = *filter.Senders
}
if filter.NotSenders != nil {
notSenders = *filter.NotSenders
}
return senders, notSenders
}
func getSendersStateFilterFilter(filter *gomatrixserverlib.StateFilter) (senders []string, notSenders []string) {
if filter.Senders != nil {
senders = *filter.Senders
}
if filter.NotSenders != nil {
notSenders = *filter.NotSenders
}
return senders, notSenders
}

View file

@ -56,12 +56,6 @@ const upsertMembershipSQL = "" +
" ON CONFLICT ON CONSTRAINT syncapi_memberships_unique" +
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
const selectMembershipSQL = "" +
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
" WHERE room_id = $1 AND user_id = $2 AND membership = ANY($3)" +
" ORDER BY stream_pos DESC" +
" LIMIT 1"
const selectMembershipCountSQL = "" +
"SELECT COUNT(*) FROM (" +
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
@ -69,7 +63,6 @@ const selectMembershipCountSQL = "" +
type membershipsStatements struct {
upsertMembershipStmt *sql.Stmt
selectMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
}
@ -82,9 +75,6 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil {
return nil, err
}
if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil {
return nil, err
}
if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil {
return nil, err
}
@ -111,14 +101,6 @@ func (s *membershipsStatements) UpsertMembership(
return err
}
func (s *membershipsStatements) SelectMembership(
ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt)
err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos)
return
}
func (s *membershipsStatements) SelectMembershipCount(
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
) (count int, err error) {

View file

@ -81,6 +81,15 @@ const insertEventSQL = "" +
const selectEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)"
const selectEventsWithFilterSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" +
" AND ( $2::text[] IS NULL OR sender = ANY($2) )" +
" AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" +
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" +
" AND ( $6::bool IS NULL OR contains_url = $6 )" +
" LIMIT $7"
const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
@ -153,6 +162,7 @@ const selectContextAfterEventSQL = "" +
type outputRoomEventsStatements struct {
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
selectEventsWitFilterStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt
selectRecentEventsForSyncStmt *sql.Stmt
@ -174,6 +184,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL},
{&s.selectEventsStmt, selectEventsSQL},
{&s.selectEventsWitFilterStmt, selectEventsWithFilterSQL},
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
{&s.selectRecentEventsStmt, selectRecentEventsSQL},
{&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL},
@ -204,11 +215,11 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
stateFilter *gomatrixserverlib.StateFilter, roomIDs []string,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err := stmt.QueryContext(
ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
pq.StringArray(stateFilter.Senders),
pq.StringArray(stateFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
stateFilter.ContainsURL,
@ -310,7 +321,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
// Parse content as JSON and search for an "url" key
containsURL := false
var content map[string]interface{}
if json.Unmarshal(event.Content(), &content) != nil {
if json.Unmarshal(event.Content(), &content) == nil {
// Set containsURL to true if url is present
_, containsURL = content["url"]
}
@ -353,10 +364,11 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
} else {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
}
senders, notSenders := getSendersRoomEventFilter(eventFilter)
rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders),
pq.StringArray(eventFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit+1,
@ -398,11 +410,12 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
) ([]types.StreamEvent, error) {
senders, notSenders := getSendersRoomEventFilter(eventFilter)
stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt)
rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders),
pq.StringArray(eventFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit,
@ -427,15 +440,52 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
// selectEvents returns the events for the given event IDs. If an event is
// missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) SelectEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string,
ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
) ([]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
var (
stmt *sql.Stmt
rows *sql.Rows
err error
)
if filter == nil {
stmt = sqlutil.TxStmt(txn, s.selectEventsStmt)
rows, err = stmt.QueryContext(ctx, pq.StringArray(eventIDs))
} else {
senders, notSenders := getSendersRoomEventFilter(filter)
stmt = sqlutil.TxStmt(txn, s.selectEventsWitFilterStmt)
rows, err = stmt.QueryContext(ctx,
pq.StringArray(eventIDs),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
filter.ContainsURL,
filter.Limit,
)
}
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
return rowsToStreamEvents(rows)
streamEvents, err := rowsToStreamEvents(rows)
if err != nil {
return nil, err
}
if preserveOrder {
eventMap := make(map[string]types.StreamEvent)
for _, ev := range streamEvents {
eventMap[ev.EventID()] = ev
}
var returnEvents []types.StreamEvent
for _, eventID := range eventIDs {
ev, ok := eventMap[eventID]
if ok {
returnEvents = append(returnEvents, ev)
}
}
return returnEvents, nil
}
return streamEvents, nil
}
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
@ -462,10 +512,11 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn
func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
) (evts []*gomatrixserverlib.HeaderedEvent, err error) {
senders, notSenders := getSendersRoomEventFilter(filter)
rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext(
ctx, roomID, id, filter.Limit,
pq.StringArray(filter.Senders),
pq.StringArray(filter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
)
@ -494,10 +545,11 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
func (s *outputRoomEventsStatements) SelectContextAfterEvent(
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) {
senders, notSenders := getSendersRoomEventFilter(filter)
rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext(
ctx, roomID, id, filter.Limit,
pq.StringArray(filter.Senders),
pq.StringArray(filter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
)

View file

@ -73,9 +73,6 @@ const selectMaxPositionInTopologySQL = "" +
"SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" +
") ORDER BY stream_position DESC LIMIT 1"
const deleteTopologyForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
const selectStreamToTopologicalPositionAscSQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;"
@ -88,7 +85,6 @@ type outputRoomEventsTopologyStatements struct {
selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt
deleteTopologyForRoomStmt *sql.Stmt
selectStreamToTopologicalPositionAscStmt *sql.Stmt
selectStreamToTopologicalPositionDescStmt *sql.Stmt
}
@ -114,9 +110,6 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) {
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
return nil, err
}
if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil {
return nil, err
}
if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil {
return nil, err
}
@ -148,9 +141,9 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
// is requested or not.
var stmt *sql.Stmt
if chronologicalOrder {
stmt = s.selectEventIDsInRangeASCStmt
stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt)
} else {
stmt = s.selectEventIDsInRangeDESCStmt
stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt)
}
// Query the event IDs.
@ -203,10 +196,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return
}
func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre
// Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events.
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs)
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false)
if err != nil {
return nil, err
}
@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
// Check if we have all of the event's previous events. If an event is
// missing, add it to the room's backward extremities.
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs())
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false)
if err != nil {
return err
}
@ -429,7 +429,8 @@ func (d *Database) updateRoomState(
func (d *Database) GetEventsInTopologicalRange(
ctx context.Context,
from, to *types.TopologyToken,
roomID string, limit int,
roomID string,
filter *gomatrixserverlib.RoomEventFilter,
backwardOrdering bool,
) (events []types.StreamEvent, err error) {
var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
@ -450,14 +451,14 @@ func (d *Database) GetEventsInTopologicalRange(
// Select the event IDs from the defined range.
var eIDs []string
eIDs, err = d.Topology.SelectEventIDsInRange(
ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, limit, !backwardOrdering,
ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
)
if err != nil {
return
}
// Retrieve the events' contents using their IDs.
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs)
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, filter, true)
return
}
@ -619,7 +620,7 @@ func (d *Database) fetchMissingStateEvents(
) ([]types.StreamEvent, error) {
// Fetch from the events table first so we pick up the stream ID for the
// event.
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs)
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false)
if err != nil {
return nil, err
}
@ -687,6 +688,9 @@ func (d *Database) GetStateDeltas(
// user has ever interacted with — joined to, kicked/banned from, left.
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
@ -704,17 +708,23 @@ func (d *Database) GetStateDeltas(
// get all the state events ever (i.e. for all available rooms) between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
// find out which rooms this user is peeking, if any.
// We do this before joins so any peeks get overwritten
peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
if err != nil {
if err != nil && err != sql.ErrNoRows {
return nil, nil, err
}
@ -725,6 +735,9 @@ func (d *Database) GetStateDeltas(
var s []types.StreamEvent
s, err = d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
if err != nil {
if err == sql.ErrNoRows {
continue
}
return nil, nil, err
}
state[peek.RoomID] = s
@ -752,6 +765,9 @@ func (d *Database) GetStateDeltas(
var s []types.StreamEvent
s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter)
if err != nil {
if err == sql.ErrNoRows {
continue
}
return nil, nil, err
}
state[roomID] = s
@ -802,6 +818,9 @@ func (d *Database) GetStateDeltasForFullStateSync(
// user has ever interacted with — joined to, kicked/banned from, left.
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
@ -818,7 +837,7 @@ func (d *Database) GetStateDeltasForFullStateSync(
deltas := make(map[string]types.StateDelta)
peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
if err != nil {
if err != nil && err != sql.ErrNoRows {
return nil, nil, err
}
@ -827,6 +846,9 @@ func (d *Database) GetStateDeltasForFullStateSync(
if !peek.Deleted {
s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
if stateErr != nil {
if stateErr == sql.ErrNoRows {
continue
}
return nil, nil, stateErr
}
deltas[peek.RoomID] = types.StateDelta{
@ -840,10 +862,16 @@ func (d *Database) GetStateDeltasForFullStateSync(
// Get all the state events ever between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil
}
return nil, nil, err
}
@ -868,6 +896,9 @@ func (d *Database) GetStateDeltasForFullStateSync(
for _, joinedRoomID := range joinedRoomIDs {
s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter)
if stateErr != nil {
if stateErr == sql.ErrNoRows {
continue
}
return nil, nil, stateErr
}
deltas[joinedRoomID] = types.StateDelta{

View file

@ -41,23 +41,23 @@ const insertAccountDataSQL = "" +
" ON CONFLICT (user_id, room_id, type) DO UPDATE" +
" SET id = $5"
// further parameters are added by prepareWithFilters
const selectAccountDataInRangeSQL = "" +
"SELECT room_id, type FROM syncapi_account_data_type" +
" WHERE user_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC"
" WHERE user_id = $1 AND id > $2 AND id <= $3"
const selectMaxAccountDataIDSQL = "" +
"SELECT MAX(id) FROM syncapi_account_data_type"
type accountDataStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
insertAccountDataStmt *sql.Stmt
selectMaxAccountDataIDStmt *sql.Stmt
selectAccountDataInRangeStmt *sql.Stmt
}
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
func NewSqliteAccountDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{
db: db,
streamIDStatements: streamID,
@ -94,18 +94,24 @@ func (s *accountDataStatements) SelectAccountDataInRange(
ctx context.Context,
userID string,
r types.Range,
accountDataFilterPart *gomatrixserverlib.EventFilter,
filter *gomatrixserverlib.EventFilter,
) (data map[string][]string, err error) {
data = make(map[string][]string)
stmt, params, err := prepareWithFilters(
s.db, nil, selectAccountDataInRangeSQL,
[]interface{}{
userID, r.Low(), r.High(),
},
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
[]string{}, nil, filter.Limit, FilterOrderAsc)
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High())
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
var entries int
for rows.Next() {
var dataType string
var roomID string
@ -114,31 +120,11 @@ func (s *accountDataStatements) SelectAccountDataInRange(
return
}
// check if we should add this by looking at the filter.
// It would be nice if we could do this in SQL-land, but the mix of variadic
// and positional parameters makes the query annoyingly hard to do, it's easier
// and clearer to do it in Go-land. If there are no filters for [not]types then
// this gets skipped.
for _, includeType := range accountDataFilterPart.Types {
if includeType != dataType { // TODO: wildcard support
continue
}
}
for _, excludeType := range accountDataFilterPart.NotTypes {
if excludeType == dataType { // TODO: wildcard support
continue
}
}
if len(data[roomID]) > 0 {
data[roomID] = append(data[roomID], dataType)
} else {
data[roomID] = []string{dataType}
}
entries++
if entries >= accountDataFilterPart.Limit {
break
}
}
return data, nil

View file

@ -47,15 +47,11 @@ const selectBackwardExtremitiesForRoomSQL = "" +
const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
const deleteBackwardExtremitiesForRoomSQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
type backwardExtremitiesStatements struct {
db *sql.DB
insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
}
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
@ -75,9 +71,6 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
return nil, err
}
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
return nil, err
}
return s, nil
}
@ -116,10 +109,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return err
}
func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -90,7 +90,7 @@ const selectEventsWithEventIDsSQL = "" +
type currentRoomStateStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
deleteRoomStateForRoomStmt *sql.Stmt
@ -100,7 +100,7 @@ type currentRoomStateStatements struct {
selectStateEventStmt *sql.Stmt
}
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{
db: db,
streamIDStatements: streamID,
@ -220,7 +220,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
},
stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes,
excludeEventIDs, stateFilter.Limit, FilterOrderNone,
excludeEventIDs, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderNone,
)
if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)

View file

@ -25,33 +25,52 @@ const (
// parts.
func prepareWithFilters(
db *sql.DB, txn *sql.Tx, query string, params []interface{},
senders, notsenders, types, nottypes []string, excludeEventIDs []string,
limit int, order FilterOrder,
senders, notsenders, types, nottypes *[]string, excludeEventIDs []string,
containsURL *bool, limit int, order FilterOrder,
) (*sql.Stmt, []interface{}, error) {
offset := len(params)
if count := len(senders); count > 0 {
if senders != nil {
if count := len(*senders); count > 0 {
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range senders {
for _, v := range *senders {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND sender = ""`
}
if count := len(notsenders); count > 0 {
}
if notsenders != nil {
if count := len(*notsenders); count > 0 {
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range notsenders {
for _, v := range *notsenders {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND sender NOT = ""`
}
if count := len(types); count > 0 {
}
if types != nil {
if count := len(*types); count > 0 {
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range types {
for _, v := range *types {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND type = ""`
}
if count := len(nottypes); count > 0 {
}
if nottypes != nil {
if count := len(*nottypes); count > 0 {
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range nottypes {
for _, v := range *nottypes {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND type NOT = ""`
}
}
if containsURL != nil {
query += fmt.Sprintf(" AND contains_url = %v", *containsURL)
}
if count := len(excludeEventIDs); count > 0 {
query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset)

View file

@ -59,14 +59,14 @@ const selectMaxInviteIDSQL = "" +
type inviteEventsStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt
deleteInviteEventStmt *sql.Stmt
selectMaxInviteIDStmt *sql.Stmt
}
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) {
s := &inviteEventsStatements{
db: db,
streamIDStatements: streamID,

View file

@ -18,7 +18,6 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
@ -57,12 +56,6 @@ const upsertMembershipSQL = "" +
" ON CONFLICT (room_id, user_id, membership)" +
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
const selectMembershipSQL = "" +
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
" WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" +
" ORDER BY stream_pos DESC" +
" LIMIT 1"
const selectMembershipCountSQL = "" +
"SELECT COUNT(*) FROM (" +
" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +
@ -111,22 +104,6 @@ func (s *membershipsStatements) UpsertMembership(
return err
}
func (s *membershipsStatements) SelectMembership(
ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
params := []interface{}{roomID, userID}
for _, membership := range memberships {
params = append(params, membership)
}
orig := strings.Replace(selectMembershipSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
stmt, err := s.db.Prepare(orig)
if err != nil {
return "", 0, 0, err
}
err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos)
return
}
func (s *membershipsStatements) SelectMembershipCount(
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
) (count int, err error) {

View file

@ -58,7 +58,7 @@ const insertEventSQL = "" +
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
const selectEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)"
const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
@ -111,9 +111,8 @@ const selectContextAfterEventSQL = "" +
type outputRoomEventsStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt
@ -122,7 +121,7 @@ type outputRoomEventsStatements struct {
selectContextAfterEventStmt *sql.Stmt
}
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) {
s := &outputRoomEventsStatements{
db: db,
streamIDStatements: streamID,
@ -133,7 +132,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
}
return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL},
{&s.selectEventsStmt, selectEventsSQL},
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
{&s.updateEventJSONStmt, updateEventJSONSQL},
{&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL},
@ -170,7 +168,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
s.db, txn, stmtSQL, inputParams,
stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes,
nil, stateFilter.Limit, FilterOrderAsc,
nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc,
)
if err != nil {
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -279,7 +277,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
// Parse content as JSON and search for an "url" key
containsURL := false
var content map[string]interface{}
if json.Unmarshal(event.Content(), &content) != nil {
if json.Unmarshal(event.Content(), &content) == nil {
// Set containsURL to true if url is present
_, containsURL = content["url"]
}
@ -347,7 +345,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
},
eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes,
nil, eventFilter.Limit+1, FilterOrderDesc,
nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc,
)
if err != nil {
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -395,7 +393,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
},
eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes,
nil, eventFilter.Limit, FilterOrderAsc,
nil, eventFilter.ContainsURL, eventFilter.Limit, FilterOrderAsc,
)
if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -421,21 +419,50 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
// selectEvents returns the events for the given event IDs. If an event is
// missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) SelectEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string,
ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
) ([]types.StreamEvent, error) {
var returnEvents []types.StreamEvent
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
for _, eventID := range eventIDs {
rows, err := stmt.QueryContext(ctx, eventID)
iEventIDs := make([]interface{}, len(eventIDs))
for i := range eventIDs {
iEventIDs[i] = eventIDs[i]
}
selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
if filter == nil {
filter = &gomatrixserverlib.RoomEventFilter{Limit: 20}
}
stmt, params, err := prepareWithFilters(
s.db, txn, selectSQL, iEventIDs,
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
)
if err != nil {
return nil, err
}
if streamEvents, err := rowsToStreamEvents(rows); err == nil {
returnEvents = append(returnEvents, streamEvents...)
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
streamEvents, err := rowsToStreamEvents(rows)
if err != nil {
return nil, err
}
if preserveOrder {
var returnEvents []types.StreamEvent
eventMap := make(map[string]types.StreamEvent)
for _, ev := range streamEvents {
eventMap[ev.EventID()] = ev
}
for _, eventID := range eventIDs {
ev, ok := eventMap[eventID]
if ok {
returnEvents = append(returnEvents, ev)
}
internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
}
return returnEvents, nil
}
return streamEvents, nil
}
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
@ -507,7 +534,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
},
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
nil, filter.Limit, FilterOrderDesc,
nil, filter.ContainsURL, filter.Limit, FilterOrderDesc,
)
rows, err := stmt.QueryContext(ctx, params...)
@ -543,7 +570,7 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent(
},
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
nil, filter.Limit, FilterOrderAsc,
nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
)
rows, err := stmt.QueryContext(ctx, params...)

View file

@ -78,7 +78,6 @@ type outputRoomEventsTopologyStatements struct {
selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt
deleteTopologyForRoomStmt *sql.Stmt
selectStreamToTopologicalPositionAscStmt *sql.Stmt
selectStreamToTopologicalPositionDescStmt *sql.Stmt
}
@ -191,10 +190,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return
}
func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -66,7 +66,7 @@ const selectMaxPeekIDSQL = "" +
type peekStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
insertPeekStmt *sql.Stmt
deletePeekStmt *sql.Stmt
deletePeeksStmt *sql.Stmt
@ -75,7 +75,7 @@ type peekStatements struct {
selectMaxPeekIDStmt *sql.Stmt
}
func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) {
func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) {
_, err := db.Exec(peeksSchema)
if err != nil {
return nil, err

View file

@ -75,7 +75,7 @@ const selectPresenceAfter = "" +
type presenceStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
upsertPresenceStmt *sql.Stmt
upsertPresenceFromSyncStmt *sql.Stmt
selectPresenceForUsersStmt *sql.Stmt
@ -83,7 +83,7 @@ type presenceStatements struct {
selectPresenceAfterStmt *sql.Stmt
}
func NewSqlitePresenceTable(db *sql.DB, streamID *streamIDStatements) (*presenceStatements, error) {
func NewSqlitePresenceTable(db *sql.DB, streamID *StreamIDStatements) (*presenceStatements, error) {
_, err := db.Exec(presenceSchema)
if err != nil {
return nil, err

View file

@ -59,13 +59,13 @@ const selectMaxReceiptIDSQL = "" +
type receiptStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
upsertReceipt *sql.Stmt
selectRoomReceipts *sql.Stmt
selectMaxReceiptID *sql.Stmt
}
func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) {
func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) {
_, err := db.Exec(receiptsSchema)
if err != nil {
return nil, err

View file

@ -32,12 +32,12 @@ const increaseStreamIDStmt = "" +
"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" +
" RETURNING stream_id"
type streamIDStatements struct {
type StreamIDStatements struct {
db *sql.DB
increaseStreamIDStmt *sql.Stmt
}
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) {
s.db = db
_, err = db.Exec(streamIDTableSchema)
if err != nil {
@ -49,31 +49,31 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
return
}
func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos)
return
}
func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *StreamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
return
}
func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *StreamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos)
return
}
func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *StreamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
return
}
func (s *streamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos)
return

View file

@ -30,7 +30,7 @@ type SyncServerDatasource struct {
shared.Database
db *sql.DB
writer sqlutil.Writer
streamID streamIDStatements
streamID StreamIDStatements
}
// NewDatabase creates a new sync server database
@ -49,7 +49,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
}
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
if err = d.streamID.prepare(d.db); err != nil {
if err = d.streamID.Prepare(d.db); err != nil {
return err
}
accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)

View file

@ -3,6 +3,7 @@ package storage_test
import (
"context"
"fmt"
"reflect"
"testing"
"github.com/matrix-org/dendrite/setup/config"
@ -38,7 +39,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
if err != nil {
t.Fatalf("WriteEvent failed: %s", err)
}
fmt.Println("Event ID", ev.EventID(), " spos=", pos, "depth=", ev.Depth())
t.Logf("Event ID %s spos=%v depth=%v", ev.EventID(), pos, ev.Depth())
positions = append(positions, pos)
}
return
@ -46,7 +47,6 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
func TestWriteEvents(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
t.Parallel()
alice := test.NewUser()
r := test.NewRoom(t, alice)
db, close := MustCreateDatabase(t, dbType)
@ -61,12 +61,18 @@ func TestRecentEventsPDU(t *testing.T) {
db, close := MustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser()
var filter gomatrixserverlib.RoomEventFilter
filter.Limit = 100
// dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
// actual test room
r := test.NewRoom(t, alice)
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"})
events := r.Events()
positions := MustWriteEvents(t, db, events)
// dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
latest, err := db.MaxStreamPositionForPDUs(ctx)
if err != nil {
t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
@ -76,69 +82,63 @@ func TestRecentEventsPDU(t *testing.T) {
Name string
From types.StreamPosition
To types.StreamPosition
Limit int
ReverseOrder bool
WantEvents []*gomatrixserverlib.HeaderedEvent
WantLimited bool
}{
// The purpose of this test is to make sure that incremental syncs are including up to the latest events.
// It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event.
// It's a basic sanity test that sync works. It creates a streaming position that is on the penultimate event.
// It makes sure the response includes the final event.
{
Name: "IncrementalSync penultimate",
Name: "penultimate",
From: positions[len(positions)-2], // pretend we are at the penultimate event
To: latest,
Limit: 100,
WantEvents: events[len(events)-1:],
WantLimited: false,
},
/*
// The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the
// number of returned events. This is critical for big rooms hence the test here.
// The purpose of this test is to check that limits can be applied and work.
// This is critical for big rooms hence the test here.
{
Name: "IncrementalSync limited",
DoSync: func() (*types.Response, error) {
from := types.StreamingToken{ // pretend we are 10 events behind
PDUPosition: positions[len(positions)-11],
}
res := types.NewResponse()
// limit is set to 5
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
Name: "limited",
From: 0,
To: latest,
Limit: 1,
WantEvents: events[len(events)-1:],
WantLimited: true,
},
// want the last 5 events, NOT the last 10.
WantTimeline: events[len(events)-5:],
},
// The purpose of this test is to check that CompleteSync returns all the current state as well as
// honouring the `numRecentEventsPerRoom` value
// The purpose of this test is to check that we can return every event with a high
// enough limit
{
Name: "CompleteSync limited",
DoSync: func() (*types.Response, error) {
res := types.NewResponse()
// limit set to 5
return db.CompleteSync(ctx, res, testUserDeviceA, 5)
Name: "large limited",
From: 0,
To: latest,
Limit: 100,
WantEvents: events,
WantLimited: false,
},
// want the last 5 events
WantTimeline: events[len(events)-5:],
// want all state for the room
WantState: state,
},
// The purpose of this test is to check that CompleteSync can return everything with a high enough
// `numRecentEventsPerRoom`.
// The purpose of this test is to check that we can return events in reverse order
{
Name: "CompleteSync",
DoSync: func() (*types.Response, error) {
res := types.NewResponse()
return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
Name: "reverse",
From: positions[len(positions)-3], // 2 events back
To: latest,
Limit: 100,
ReverseOrder: true,
WantEvents: test.Reversed(events[len(events)-2:]),
WantLimited: false,
},
WantTimeline: events,
// We want no state at all as that field in /sync is the delta between the token (beginning of time)
// and the START of the timeline.
}, */
}
for _, tc := range testCases {
for i := range testCases {
tc := testCases[i]
t.Run(tc.Name, func(st *testing.T) {
var filter gomatrixserverlib.RoomEventFilter
filter.Limit = tc.Limit
gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
From: tc.From,
To: tc.To,
}, &filter, true, true)
}, &filter, !tc.ReverseOrder, true)
if err != nil {
st.Fatalf("failed to do sync: %s", err)
}
@ -148,100 +148,49 @@ func TestRecentEventsPDU(t *testing.T) {
if len(gotEvents) != len(tc.WantEvents) {
st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents))
}
for j := range gotEvents {
if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) {
st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON()))
}
}
})
}
})
}
/*
func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
t.Parallel()
db := MustCreateDatabase(t)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
positions := MustWriteEvents(t, db, events)
latest, err := db.SyncPosition(ctx)
if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err)
}
from := types.StreamingToken{
PDUPosition: positions[len(positions)-2],
}
res := types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
if err != nil {
t.Fatalf("failed to IncrementalSync with latest token")
}
roomRes, ok := res.Rooms.Join[testRoomID]
if !ok {
t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res)
}
// returns the last event "Message 10"
assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:]))
prev := roomRes.Timeline.PrevBatch.String()
if prev == "" {
t.Fatalf("IncrementalSync expected prev_batch token")
}
prevBatchToken, err := types.NewTopologyTokenFromString(prev)
if err != nil {
t.Fatalf("failed to NewTopologyTokenFromString : %s", err)
}
// backpaginate 5 messages starting at the latest position.
// head towards the beginning of time
to := types.TopologyToken{}
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true)
if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err)
}
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1]))
}
// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token.
func TestGetEventsInRangeWithStreamToken(t *testing.T) {
t.Parallel()
db := MustCreateDatabase(t)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
MustWriteEvents(t, db, events)
latest, err := db.SyncPosition(ctx)
if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err)
}
// head towards the beginning of time
to := types.StreamingToken{}
// backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err)
}
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
}
// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token
func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
t.Parallel()
db := MustCreateDatabase(t)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
MustWriteEvents(t, db, events)
from, err := db.MaxTopologicalPosition(ctx, testRoomID)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser()
r := test.NewRoom(t, alice)
for i := 0; i < 10; i++ {
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})
}
events := r.Events()
_ = MustWriteEvents(t, db, events)
from, err := db.MaxTopologicalPosition(ctx, r.ID)
if err != nil {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
}
t.Logf("max topo pos = %+v", from)
// head towards the beginning of time
to := types.TopologyToken{}
// backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true)
filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err)
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
}
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
gots := db.StreamEventsToEvents(nil, paginatedEvents)
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
})
}
/*
// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth.
// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent
// will appear FIRST when going backwards. This test creates a DAG like:
@ -651,12 +600,4 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ
tok.Decrement()
return &tok
}
func reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[len(in)-i-1]
}
return out
}
*/

View file

@ -59,7 +59,7 @@ type Events interface {
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
// SelectEarlyEvents returns the earliest events in the given room.
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error)
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
@ -84,8 +84,6 @@ type Topology interface {
SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error)
// SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position.
SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error)
// DeleteTopologyForRoom removes all topological information for a room. This should only be done when removing the room entirely.
DeleteTopologyForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
// SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room.
SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error)
}
@ -132,8 +130,6 @@ type BackwardsExtremities interface {
SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error)
// DeleteBackwardExtremity removes a backwards extremity for a room, if one existed.
DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error)
// DeleteBackwardExtremitiesFoorRoomID removes all backward extremities for a room. This should only be done when removing the room entirely.
DeleteBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
}
// SendToDevice tracks send-to-device messages which are sent to individual
@ -173,7 +169,6 @@ type Receipts interface {
type Memberships interface {
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error)
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
}

View file

@ -0,0 +1,105 @@
package tables_test
import (
"context"
"database/sql"
"fmt"
"reflect"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
)
func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.Events
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresEventsTable(db)
case test.DBTypeSQLite:
var stream sqlite3.StreamIDStatements
if err = stream.Prepare(db); err != nil {
t.Fatalf("failed to prepare stream stmts: %s", err)
}
tab, err = sqlite3.NewSqliteEventsTable(db, &stream)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, db, close
}
func TestOutputRoomEventsTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newOutputRoomEventsTable(t, dbType)
defer close()
events := room.Events()
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
for _, ev := range events {
_, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false)
if err != nil {
return fmt.Errorf("failed to InsertEvent: %s", err)
}
}
// order = 2,0,3,1
wantEventIDs := []string{
events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(),
}
gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, nil, true)
if err != nil {
return fmt.Errorf("failed to SelectEvents: %s", err)
}
gotEventIDs := make([]string, len(gotEvents))
for i := range gotEvents {
gotEventIDs[i] = gotEvents[i].EventID()
}
if !reflect.DeepEqual(gotEventIDs, wantEventIDs) {
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs)
}
// Test that contains_url is correctly populated
urlEv := room.CreateEvent(t, alice, "m.text", map[string]interface{}{
"body": "test.txt",
"url": "mxc://test.txt",
})
if _, err = tab.InsertEvent(ctx, txn, urlEv, nil, nil, nil, false); err != nil {
return fmt.Errorf("failed to InsertEvent: %s", err)
}
wantEventID := []string{urlEv.EventID()}
t := true
gotEvents, err = tab.SelectEvents(ctx, txn, wantEventID, &gomatrixserverlib.RoomEventFilter{Limit: 1, ContainsURL: &t}, true)
if err != nil {
return fmt.Errorf("failed to SelectEvents: %s", err)
}
gotEventIDs = make([]string, len(gotEvents))
for i := range gotEvents {
gotEventIDs[i] = gotEvents[i].EventID()
}
if !reflect.DeepEqual(gotEventIDs, wantEventID) {
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventID)
}
return nil
})
if err != nil {
t.Fatalf("err: %s", err)
}
})
}

View file

@ -0,0 +1,91 @@
package tables_test
import (
"context"
"database/sql"
"fmt"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
)
func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.Topology
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresTopologyTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSqliteTopologyTable(db)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, db, close
}
func TestTopologyTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newTopologyTable(t, dbType)
defer close()
events := room.Events()
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
var highestPos types.StreamPosition
for i, ev := range events {
topoPos, err := tab.InsertEventInTopology(ctx, txn, ev, types.StreamPosition(i))
if err != nil {
return fmt.Errorf("failed to InsertEventInTopology: %s", err)
}
// topo pos = depth, depth starts at 1, hence 1+i
if topoPos != types.StreamPosition(1+i) {
return fmt.Errorf("got topo pos %d want %d", topoPos, 1+i)
}
highestPos = topoPos + 1
}
// check ordering works without limit
eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, events[:])
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:]))
// check ordering works with limit
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, events[:3])
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:]))
return nil
})
if err != nil {
t.Fatalf("err: %s", err)
}
})
}

View file

@ -3,9 +3,11 @@ package streams
import (
"context"
"database/sql"
"fmt"
"sync"
"time"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
@ -26,7 +28,8 @@ type PDUStreamProvider struct {
tasks chan func()
workers atomic.Int32
userAPI userapi.UserInternalAPI
// userID+deviceID -> lazy loading cache
lazyLoadCache *caching.LazyLoadCache
}
func (p *PDUStreamProvider) worker() {
@ -188,7 +191,7 @@ func (p *PDUStreamProvider) IncrementalSync(
newPos = from
for _, delta := range stateDeltas {
var pos types.StreamPosition
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, req.Response); err != nil {
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, &stateFilter, req.Response); err != nil {
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
return to
}
@ -203,12 +206,14 @@ func (p *PDUStreamProvider) IncrementalSync(
return newPos
}
// nolint:gocyclo
func (p *PDUStreamProvider) addRoomDeltaToResponse(
ctx context.Context,
device *userapi.Device,
r types.Range,
delta types.StateDelta,
eventFilter *gomatrixserverlib.RoomEventFilter,
stateFilter *gomatrixserverlib.StateFilter,
res *types.Response,
) (types.StreamPosition, error) {
if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave {
@ -225,13 +230,16 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
eventFilter, true, true,
)
if err != nil {
return r.From, err
if err == sql.ErrNoRows {
return r.To, nil
}
return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err)
}
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back
prevBatch, err := p.DB.GetBackwardTopologyPos(ctx, recentStreamEvents)
if err != nil {
return r.From, err
return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err)
}
// If we didn't return any events at all then don't bother doing anything else.
@ -247,7 +255,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// room that were returned.
latestPosition := r.To
updateLatestPosition := func(mostRecentEventID string) {
if _, pos, err := p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil {
var pos types.StreamPosition
if _, pos, err = p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil {
switch {
case r.Backwards && pos > latestPosition:
fallthrough
@ -263,6 +272,16 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
}
if stateFilter.LazyLoadMembers {
delta.StateEvents, err = p.lazyLoadMembers(
ctx, delta.RoomID, true, limited, stateFilter.IncludeRedundantMembers,
device, recentEvents, delta.StateEvents,
)
if err != nil && err != sql.ErrNoRows {
return r.From, fmt.Errorf("p.lazyLoadMembers: %w", err)
}
}
hasMembershipChange := false
for _, recentEvent := range recentStreamEvents {
if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil {
@ -322,12 +341,16 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
wantFullState bool,
device *userapi.Device,
) (jr *types.JoinResponse, err error) {
jr = types.NewJoinResponse()
// TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
recentStreamEvents, limited, err := p.DB.RecentEvents(
ctx, roomID, r, eventFilter, true, true,
)
if err != nil {
if err == sql.ErrNoRows {
return jr, nil
}
return
}
@ -402,7 +425,20 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
// "Can sync a room with a message with a transaction id" - which does a complete sync to check.
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents)
jr = types.NewJoinResponse()
if stateFilter.LazyLoadMembers {
if err != nil {
return nil, err
}
stateEvents, err = p.lazyLoadMembers(ctx, roomID,
false, limited, stateFilter.IncludeRedundantMembers,
device, recentEvents, stateEvents,
)
if err != nil && err != sql.ErrNoRows {
return nil, err
}
}
jr.Summary.JoinedMemberCount = &joinedCount
jr.Summary.InvitedMemberCount = &invitedCount
jr.Timeline.PrevBatch = prevBatch
@ -412,6 +448,69 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
return jr, nil
}
func (p *PDUStreamProvider) lazyLoadMembers(
ctx context.Context, roomID string,
incremental, limited, includeRedundant bool,
device *userapi.Device,
timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
if len(timelineEvents) == 0 {
return stateEvents, nil
}
// Work out which memberships to include
timelineUsers := make(map[string]struct{})
if !incremental {
timelineUsers[device.UserID] = struct{}{}
}
// Add all users the client doesn't know about yet to a list
for _, event := range timelineEvents {
// Membership is not yet cached, add it to the list
if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.Sender()); !ok {
timelineUsers[event.Sender()] = struct{}{}
}
}
// Preallocate with the same amount, even if it will end up with fewer values
newStateEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(stateEvents))
// Remove existing membership events we don't care about, e.g. users not in the timeline.events
for _, event := range stateEvents {
if event.Type() == gomatrixserverlib.MRoomMember && event.StateKey() != nil {
// If this is a gapped incremental sync, we still want this membership
isGappedIncremental := limited && incremental
// We want this users membership event, keep it in the list
_, ok := timelineUsers[event.Sender()]
wantMembership := ok || isGappedIncremental
if wantMembership {
newStateEvents = append(newStateEvents, event)
if !includeRedundant {
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, event.Sender(), event.EventID())
}
delete(timelineUsers, event.Sender())
}
} else {
newStateEvents = append(newStateEvents, event)
}
}
wantUsers := make([]string, 0, len(timelineUsers))
for userID := range timelineUsers {
wantUsers = append(wantUsers, userID)
}
// Query missing membership events
memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &gomatrixserverlib.StateFilter{
Limit: 100,
Senders: &wantUsers,
Types: &[]string{gomatrixserverlib.MRoomMember},
})
if err != nil {
return stateEvents, err
}
// cache the membership events
for _, membership := range memberships {
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, membership.Sender(), membership.EventID())
}
stateEvents = append(newStateEvents, memberships...)
return stateEvents, nil
}
// addIgnoredUsersToFilter adds ignored users to the eventfilter and
// the syncreq itself for further use in streams.
func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error {
@ -423,8 +522,12 @@ func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *ty
return err
}
req.IgnoredUsers = *ignores
userList := make([]string, 0, len(ignores.List))
for userID := range ignores.List {
eventFilter.NotSenders = append(eventFilter.NotSenders, userID)
userList = append(userList, userID)
}
if len(userList) > 0 {
eventFilter.NotSenders = &userList
}
return nil
}

View file

@ -27,12 +27,12 @@ type Streams struct {
func NewSyncStreamProviders(
d storage.Database, userAPI userapi.UserInternalAPI,
rsAPI rsapi.RoomserverInternalAPI, keyAPI keyapi.KeyInternalAPI,
eduCache *caching.EDUCache, notifier *notifier.Notifier,
eduCache *caching.EDUCache, lazyLoadCache *caching.LazyLoadCache, notifier *notifier.Notifier,
) *Streams {
streams := &Streams{
PDUStreamProvider: &PDUStreamProvider{
StreamProvider: StreamProvider{DB: d},
userAPI: userAPI,
lazyLoadCache: lazyLoadCache,
},
TypingStreamProvider: &TypingStreamProvider{
StreamProvider: StreamProvider{DB: d},

View file

@ -15,6 +15,7 @@
package sync
import (
"database/sql"
"encoding/json"
"fmt"
"net/http"
@ -60,10 +61,10 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil {
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil && err != sql.ErrNoRows {
util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed")
return nil, fmt.Errorf("syncDB.GetFilter: %w", err)
} else {
} else if f != nil {
filter = *f
}
}

View file

@ -57,8 +57,12 @@ func AddPublicRoutes(
}
eduCache := caching.NewTypingCache()
lazyLoadCache, err := caching.NewLazyLoadCache()
if err != nil {
logrus.WithError(err).Panicf("failed to create lazy loading cache")
}
notifier := notifier.NewNotifier()
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, notifier)
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, lazyLoadCache, notifier)
notifier.SetCurrentPosition(streams.Latest(context.Background()))
if err = notifier.Load(context.Background(), syncDB); err != nil {
logrus.WithError(err).Panicf("failed to load notifier ")

View file

@ -312,10 +312,10 @@ Inbound federation can return events
Inbound federation can return missing events for world_readable visibility
Inbound federation can return missing events for invite visibility
Inbound federation can get public room list
POST /rooms/:room_id/redact/:event_id as power user redacts message
POST /rooms/:room_id/redact/:event_id as original message sender redacts message
POST /rooms/:room_id/redact/:event_id as random user does not redact message
POST /redact disallows redaction of event in different room
PUT /rooms/:room_id/redact/:event_id/:txn_id as power user redacts message
PUT /rooms/:room_id/redact/:event_id/:txn_id as original message sender redacts message
PUT /rooms/:room_id/redact/:event_id/:txn_id as random user does not redact message
PUT /redact disallows redaction of event in different room
An event which redacts itself should be ignored
A pair of events which redact each other should be ignored
Redaction of a redaction redacts the redaction reason
@ -697,3 +697,16 @@ Room state after a rejected state event is the same as before
Ignore user in existing room
Ignore invite in full sync
Ignore invite in incremental sync
A filtered timeline reaches its limit
A change to displayname should not result in a full state sync
Can fetch images in room
The only membership state included in an initial sync is for all the senders in the timeline
The only membership state included in an incremental sync is for senders in the timeline
Old members are included in gappy incr LL sync if they start speaking
We do send redundant membership state across incremental syncs if asked
Rejecting invite over federation doesn't break incremental /sync
Gapped incremental syncs include all state changes
Old leaves are present in gapped incremental syncs
Leaves are present in non-gapped incremental syncs
Members from the gap are included in gappy incr LL sync
Presence can be set from sync

View file

@ -15,12 +15,16 @@
package test
import (
"crypto/sha256"
"database/sql"
"encoding/hex"
"fmt"
"os"
"os/exec"
"os/user"
"testing"
"github.com/lib/pq"
)
type DBType int
@ -30,7 +34,7 @@ var DBTypePostgres DBType = 2
var Quiet = false
func createLocalDB(dbName string) string {
func createLocalDB(dbName string) {
if !Quiet {
fmt.Println("Note: tests require a postgres install accessible to the current user")
}
@ -43,7 +47,29 @@ func createLocalDB(dbName string) string {
if err != nil && !Quiet {
fmt.Println("createLocalDB returned error:", err)
}
return dbName
}
func createRemoteDB(t *testing.T, dbName, user, connStr string) {
db, err := sql.Open("postgres", connStr+" dbname=postgres")
if err != nil {
t.Fatalf("failed to open postgres conn with connstr=%s : %s", connStr, err)
}
_, err = db.Exec(fmt.Sprintf(`CREATE DATABASE %s;`, dbName))
if err != nil {
pqErr, ok := err.(*pq.Error)
if !ok {
t.Fatalf("failed to CREATE DATABASE: %s", err)
}
// we ignore duplicate database error as we expect this
if pqErr.Code != "42P04" {
t.Fatalf("failed to CREATE DATABASE with code=%s msg=%s", pqErr.Code, pqErr.Message)
}
}
_, err = db.Exec(fmt.Sprintf(`GRANT ALL PRIVILEGES ON DATABASE %s TO %s`, dbName, user))
if err != nil {
t.Fatalf("failed to GRANT: %s", err)
}
_ = db.Close()
}
func currentUser() string {
@ -64,6 +90,7 @@ func currentUser() string {
// TODO: namespace for concurrent package tests
func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) {
if dbType == DBTypeSQLite {
// this will be made in the current working directory which namespaces concurrent package runs correctly
dbname := "dendrite_test.db"
return fmt.Sprintf("file:%s", dbname), func() {
err := os.Remove(dbname)
@ -79,13 +106,9 @@ func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, clo
if user == "" {
user = currentUser()
}
dbName := os.Getenv("POSTGRES_DB")
if dbName == "" {
dbName = createLocalDB("dendrite_test")
}
connStr = fmt.Sprintf(
"user=%s dbname=%s sslmode=disable",
user, dbName,
"user=%s sslmode=disable",
user,
)
// optional vars, used in CI
password := os.Getenv("POSTGRES_PASSWORD")
@ -97,6 +120,25 @@ func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, clo
connStr += fmt.Sprintf(" host=%s", host)
}
// superuser database
postgresDB := os.Getenv("POSTGRES_DB")
// we cannot use 'dendrite_test' here else 2x concurrently running packages will try to use the same db.
// instead, hash the current working directory, snaffle the first 16 bytes and append that to dendrite_test
// and use that as the unique db name. We do this because packages are per-directory hence by hashing the
// working (test) directory we ensure we get a consistent hash and don't hash against concurrent packages.
wd, err := os.Getwd()
if err != nil {
t.Fatalf("cannot get working directory: %s", err)
}
hash := sha256.Sum256([]byte(wd))
dbName := fmt.Sprintf("dendrite_test_%s", hex.EncodeToString(hash[:16]))
if postgresDB == "" { // local server, use createdb
createLocalDB(dbName)
} else { // remote server, shell into the postgres user and CREATE DATABASE
createRemoteDB(t, dbName, user, connStr)
}
connStr += fmt.Sprintf(" dbname=%s", dbName)
return connStr, func() {
// Drop all tables on the database to get a fresh instance
db, err := sql.Open("postgres", connStr)
@ -121,6 +163,7 @@ func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) {
for dbName, dbType := range dbs {
dbt := dbType
t.Run(dbName, func(tt *testing.T) {
tt.Parallel()
testFn(tt, dbt)
})
}

View file

@ -15,7 +15,9 @@
package test
import (
"bytes"
"crypto/ed25519"
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
@ -49,3 +51,40 @@ func WithUnsigned(unsigned interface{}) eventModifier {
e.unsigned = unsigned
}
}
// Reverse a list of events
func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[len(in)-i-1]
}
return out
}
func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*gomatrixserverlib.HeaderedEvent) {
t.Helper()
if len(gotEventIDs) != len(wants) {
t.Fatalf("length mismatch: got %d events, want %d", len(gotEventIDs), len(wants))
}
for i := range wants {
w := wants[i].EventID()
g := gotEventIDs[i]
if w != g {
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
}
}
}
func AssertEventsEqual(t *testing.T, gots, wants []*gomatrixserverlib.HeaderedEvent) {
t.Helper()
if len(gots) != len(wants) {
t.Fatalf("length mismatch: got %d events, want %d", len(gots), len(wants))
}
for i := range wants {
w := wants[i].JSON()
g := gots[i].JSON()
if !bytes.Equal(w, g) {
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
}
}
}

View file

@ -494,7 +494,7 @@ type PerformPusherDeletionRequest struct {
type Pusher struct {
SessionID int64 `json:"session_id,omitempty"`
PushKey string `json:"pushkey"`
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
PushKeyTS int64 `json:"pushkey_ts,omitempty"`
Kind PusherKind `json:"kind"`
AppID string `json:"app_id"`
AppDisplayName string `json:"app_display_name"`

View file

@ -653,7 +653,7 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
}
if req.Pusher.PushKeyTS == 0 {
req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now())
req.Pusher.PushKeyTS = int64(time.Now().Unix())
}
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
}

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
@ -95,7 +94,7 @@ type pushersStatements struct {
// Returns nil error success.
func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
) error {
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
@ -95,7 +94,7 @@ type pushersStatements struct {
// Returns nil error success.
func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
) error {
_, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)

View file

@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
type AccountDataTable interface {
@ -96,7 +95,7 @@ type ThreePIDTable interface {
}
type PusherTable interface {
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error