diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 608599498..9cc94d650 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -52,6 +52,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" + pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeRouter "github.com/matrix-org/pinecone/router" pineconeSessions "github.com/matrix-org/pinecone/sessions" @@ -71,11 +72,9 @@ type DendriteMonolith struct { PineconeRouter *pineconeRouter.Router PineconeMulticast *pineconeMulticast.Multicast PineconeQUIC *pineconeSessions.Sessions + PineconeManager *pineconeConnections.ConnectionManager StorageDirectory string CacheDirectory string - staticPeerURI string - staticPeerMutex sync.RWMutex - staticPeerAttempt chan struct{} listener net.Listener httpServer *http.Server processContext *process.ProcessContext @@ -104,15 +103,8 @@ func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) { } func (m *DendriteMonolith) SetStaticPeer(uri string) { - m.staticPeerMutex.Lock() - m.staticPeerURI = strings.TrimSpace(uri) - m.staticPeerMutex.Unlock() - m.DisconnectType(int(pineconeRouter.PeerTypeRemote)) - if uri != "" { - go func() { - m.staticPeerAttempt <- struct{}{} - }() - } + m.PineconeManager.RemovePeers() + m.PineconeManager.AddPeer(strings.TrimSpace(uri)) } func (m *DendriteMonolith) DisconnectType(peertype int) { @@ -210,43 +202,6 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e return loginRes.Device.AccessToken, nil } -func (m *DendriteMonolith) staticPeerConnect() { - connected := map[string]bool{} // URI -> connected? 
- attempt := func() { - m.staticPeerMutex.RLock() - uri := m.staticPeerURI - m.staticPeerMutex.RUnlock() - if uri == "" { - return - } - for k := range connected { - delete(connected, k) - } - for _, uri := range strings.Split(uri, ",") { - connected[strings.TrimSpace(uri)] = false - } - for _, info := range m.PineconeRouter.Peers() { - connected[info.URI] = true - } - for k, online := range connected { - if !online { - if err := conn.ConnectToPeer(m.PineconeRouter, k); err != nil { - logrus.WithError(err).Error("Failed to connect to static peer") - } - } - } - } - for { - select { - case <-m.processContext.Context().Done(): - case <-m.staticPeerAttempt: - attempt() - case <-time.After(time.Second * 5): - attempt() - } - } -} - // nolint:gocyclo func (m *DendriteMonolith) Start() { var err error @@ -284,6 +239,7 @@ func (m *DendriteMonolith) Start() { m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false) m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"}) m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter) + m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter) prefix := hex.EncodeToString(pk) cfg := &config.Dendrite{} @@ -392,9 +348,6 @@ func (m *DendriteMonolith) Start() { m.processContext = base.ProcessContext - m.staticPeerAttempt = make(chan struct{}, 1) - go m.staticPeerConnect() - go func() { m.logger.Info("Listening on ", cfg.Global.ServerName) m.logger.Fatal(m.httpServer.Serve(m.PineconeQUIC.Protocol("matrix"))) diff --git a/clientapi/routing/voip.go b/clientapi/routing/voip.go index 13dca7ac0..c7ddaabcf 100644 --- a/clientapi/routing/voip.go +++ b/clientapi/routing/voip.go @@ -52,6 +52,7 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client if turnConfig.SharedSecret != "" { expiry := time.Now().Add(duration).Unix() + 
resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID) mac := hmac.New(sha1.New, []byte(turnConfig.SharedSecret)) _, err := mac.Write([]byte(resp.Username)) @@ -60,7 +61,6 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client return jsonerror.InternalServerError() } - resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID) resp.Password = base64.StdEncoding.EncodeToString(mac.Sum(nil)) } else if turnConfig.Username != "" && turnConfig.Password != "" { resp.Username = turnConfig.Username diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index a3d3ed175..dd1ab3697 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -25,7 +25,6 @@ import ( "net" "net/http" "os" - "strings" "time" "github.com/gorilla/mux" @@ -47,6 +46,7 @@ import ( "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" + pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeRouter "github.com/matrix-org/pinecone/router" pineconeSessions "github.com/matrix-org/pinecone/sessions" @@ -90,6 +90,13 @@ func main() { } pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false) + pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) + pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) + pManager := pineconeConnections.NewConnectionManager(pRouter) + pMulticast.Start() + if instancePeer != nil && *instancePeer != "" { + pManager.AddPeer(*instancePeer) + } go func() { listener, err := net.Listen("tcp", *instanceListen) @@ -119,36 +126,6 @@ func main() { } }() - pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) - pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) - 
pMulticast.Start() - - connectToStaticPeer := func() { - connected := map[string]bool{} // URI -> connected? - for _, uri := range strings.Split(*instancePeer, ",") { - connected[strings.TrimSpace(uri)] = false - } - attempt := func() { - for k := range connected { - connected[k] = false - } - for _, info := range pRouter.Peers() { - connected[info.URI] = true - } - for k, online := range connected { - if !online { - if err := conn.ConnectToPeer(pRouter, k); err != nil { - logrus.WithError(err).Error("Failed to connect to static peer") - } - } - } - } - for { - attempt() - time.Sleep(time.Second * 5) - } - } - cfg := &config.Dendrite{} cfg.Defaults(true) cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) @@ -268,7 +245,6 @@ func main() { Handler: pMux, } - go connectToStaticPeer() go func() { pubkey := pRouter.PublicKey() logrus.Info("Listening on ", hex.EncodeToString(pubkey[:])) diff --git a/cmd/dendritejs-pinecone/main.go b/cmd/dendritejs-pinecone/main.go index ba9edf230..211b3e131 100644 --- a/cmd/dendritejs-pinecone/main.go +++ b/cmd/dendritejs-pinecone/main.go @@ -22,7 +22,6 @@ import ( "encoding/hex" "fmt" "syscall/js" - "time" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/appservice" @@ -44,6 +43,7 @@ import ( _ "github.com/matrix-org/go-sqlite3-js" + pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeRouter "github.com/matrix-org/pinecone/router" pineconeSessions "github.com/matrix-org/pinecone/sessions" ) @@ -154,6 +154,8 @@ func startup() { pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false) pSessions := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) + pManager := pineconeConnections.NewConnectionManager(pRouter) + pManager.AddPeer("wss://pinecone.matrix.org/public") cfg := &config.Dendrite{} cfg.Defaults(true) @@ -237,20 +239,4 @@ func startup() { } s.ListenAndServe("fetch") }() - - // Connect to the static 
peer - go func() { - for { - if pRouter.PeerCount(pineconeRouter.PeerTypeRemote) == 0 { - if err := conn.ConnectToPeer(pRouter, publicPeer); err != nil { - logrus.WithError(err).Error("Failed to connect to static peer") - } - } - select { - case <-base.ProcessContext.Context().Done(): - return - case <-time.After(time.Second * 5): - } - } - }() } diff --git a/go.mod b/go.mod index 0e433b137..ba222ed8f 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module github.com/matrix-org/dendrite -replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e +replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9 -replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c +replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e require ( github.com/Arceliar/ironwood v0.0.0-20211125050254-8951369625d0 @@ -27,15 +27,15 @@ require ( github.com/juju/testing v0.0.0-20220203020004-a0ff61f03494 // indirect github.com/lib/pq v1.10.5 github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e - github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d + github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5 - github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d + github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f + github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 github.com/miekg/dns v1.1.31 // indirect github.com/nats-io/nats-server/v2 v2.7.4-0.20220309205833-773636c1c5bb - github.com/nats-io/nats.go 
v1.13.1-0.20220308171302-2f2f6968e98d + github.com/nats-io/nats.go v1.14.0 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31 @@ -56,6 +56,7 @@ require ( golang.org/x/image v0.0.0-20220321031419-a8550c1d254a golang.org/x/mobile v0.0.0-20220407111146-e579adbbc4a2 golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3 + golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index 959a1a5a2..8bb306a82 100644 --- a/go.sum +++ b/go.sum @@ -789,15 +789,15 @@ github.com/masterzen/winrm v0.0.0-20161014151040-7a535cd943fc/go.mod h1:CfZSN7zw github.com/masterzen/xmlpath v0.0.0-20140218185901-13f4951698ad/go.mod h1:A0zPC53iKKKcXYxr4ROjpQRQ5FgJXtelNdSmHHuq/tY= github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e h1:DP5RC0Z3XdyBEW5dKt8YPeN6vZbm6OzVaGVp7f1BQRM= github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg= -github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d h1:mGhPVaTht5NViFN/UpdrIlRApmH2FWcVaKUH5MdBKiY= -github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= +github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw2QV3YD/fRrzEDPNGgTlJlvXY0EHHnT87wF3OA= +github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod 
h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5 h1:Fkennny7+Z/5pygrhjFMZbz1j++P2hhhLoT7NO3p8DQ= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48= -github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d h1:1+T4eOPRsf6cr0lMPW4oO2k8TTHm4mqIh65kpEID5Rk= -github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f h1:MZrl4TgTnlaOn2Cu9gJCoJ3oyW5mT4/3QIZGgZXzKl4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48= +github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 h1:W0sjjC6yjskHX4mb0nk3p0fXAlbU5bAFUFeEtlrPASE= +github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= @@ -878,8 +878,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= -github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296 h1:vU9tpM3apjYlLLeY23zRWJ9Zktr5jp+mloR942LEOpY= -github.com/nats-io/jwt/v2 
v2.2.1-0.20220113022732-58e87895b296/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= +github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a h1:lem6QCvxR0Y28gth9P+wV2K/zYUUAkJ+55U8cpS0p5I= +github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= @@ -888,10 +888,10 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e h1:5tEHLzvDeS6IeqO2o9FFhsE3V2erYj8FlMt2J91wzsk= -github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e/go.mod h1:1vZ2Nijh8tcyNe8BDVyTviCd9NYzRbubQYiEHsvOQWc= -github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q= -github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= +github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9 h1:VGU5HYAwy8LRbSkrT+kCHvujVmwK8Aa/vc1O+eReTbM= +github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9/go.mod h1:5vic7C58BFEVltiZhs7Kq81q2WcEPhJPsmNv1FOrdv0= +github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e h1:kNIzIzj2OvnlreA+sTJ12nWJzTP3OSLNKDL/Iq9mF6Y= +github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e/go.mod 
h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 h1:lrVQzBtkeQEGGYUHwSX1XPe1E5GL6U3KYCNe2G4bncQ= github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= @@ -1541,8 +1541,9 @@ golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64 h1:D1v9ucDTYBtbz5vNuBbAhIMAGhQhJ6Ym5ah3maMVNX4= golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 h1:QyVthZKMsyaQwBTJE04jdNN0Pp5Fn9Qga0mrgxyERQM= +golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/internal/caching/cache_lazy_load_members.go b/internal/caching/cache_lazy_load_members.go new file mode 100644 index 000000000..71a317624 --- /dev/null +++ b/internal/caching/cache_lazy_load_members.go @@ -0,0 +1,86 @@ +package caching + +import ( + "fmt" + "time" + + userapi "github.com/matrix-org/dendrite/userapi/api" +) + +const ( + LazyLoadCacheName = "lazy_load_members" + LazyLoadCacheMaxEntries = 128 + 
LazyLoadCacheMaxUserEntries = 128 + LazyLoadCacheMutable = true + LazyLoadCacheMaxAge = time.Minute * 30 +) + +type LazyLoadCache struct { + // InMemoryLRUCachePartition containing other InMemoryLRUCachePartitions + // with the actual cached members + userCaches *InMemoryLRUCachePartition +} + +// NewLazyLoadCache creates a new LazyLoadCache. +func NewLazyLoadCache() (*LazyLoadCache, error) { + cache, err := NewInMemoryLRUCachePartition( + LazyLoadCacheName, + LazyLoadCacheMutable, + LazyLoadCacheMaxEntries, + LazyLoadCacheMaxAge, + true, + ) + if err != nil { + return nil, err + } + go cacheCleaner(cache) + return &LazyLoadCache{ + userCaches: cache, + }, nil +} + +func (c *LazyLoadCache) lazyLoadCacheForUser(device *userapi.Device) (*InMemoryLRUCachePartition, error) { + cacheName := fmt.Sprintf("%s/%s", device.UserID, device.ID) + userCache, ok := c.userCaches.Get(cacheName) + if ok && userCache != nil { + if cache, ok := userCache.(*InMemoryLRUCachePartition); ok { + return cache, nil + } + } + cache, err := NewInMemoryLRUCachePartition( + LazyLoadCacheName, + LazyLoadCacheMutable, + LazyLoadCacheMaxUserEntries, + LazyLoadCacheMaxAge, + false, + ) + if err != nil { + return nil, err + } + c.userCaches.Set(cacheName, cache) + go cacheCleaner(cache) + return cache, nil +} + +func (c *LazyLoadCache) StoreLazyLoadedUser(device *userapi.Device, roomID, userID, eventID string) { + cache, err := c.lazyLoadCacheForUser(device) + if err != nil { + return + } + cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID) + cache.Set(cacheKey, eventID) +} + +func (c *LazyLoadCache) IsLazyLoadedUserCached(device *userapi.Device, roomID, userID string) (string, bool) { + cache, err := c.lazyLoadCacheForUser(device) + if err != nil { + return "", false + } + + cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID) + val, ok := cache.Get(cacheKey) + if !ok { + return "", ok + } + return val.(string), ok +} diff --git 
a/internal/pushgateway/pushgateway.go b/internal/pushgateway/pushgateway.go index 88c326eb2..1817a040b 100644 --- a/internal/pushgateway/pushgateway.go +++ b/internal/pushgateway/pushgateway.go @@ -3,8 +3,6 @@ package pushgateway import ( "context" "encoding/json" - - "github.com/matrix-org/gomatrixserverlib" ) // A Client is how interactions with a Push Gateway is done. @@ -47,11 +45,11 @@ type Counts struct { } type Device struct { - AppID string `json:"app_id"` // Required - Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys. - PushKey string `json:"pushkey"` // Required - PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"` - Tweaks map[string]interface{} `json:"tweaks,omitempty"` + AppID string `json:"app_id"` // Required + Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys. + PushKey string `json:"pushkey"` // Required + PushKeyTS int64 `json:"pushkey_ts,omitempty"` + Tweaks map[string]interface{} `json:"tweaks,omitempty"` } type Prio string diff --git a/mediaapi/mediaapi.go b/mediaapi/mediaapi.go index c010981c0..e5daf480d 100644 --- a/mediaapi/mediaapi.go +++ b/mediaapi/mediaapi.go @@ -32,7 +32,7 @@ func AddPublicRoutes( userAPI userapi.UserInternalAPI, client *gomatrixserverlib.Client, ) { - mediaDB, err := storage.Open(&cfg.Database) + mediaDB, err := storage.NewMediaAPIDatasource(&cfg.Database) if err != nil { logrus.WithError(err).Panicf("failed to connect to media db") } diff --git a/mediaapi/routing/upload.go b/mediaapi/routing/upload.go index f762b2ff5..972c52af0 100644 --- a/mediaapi/routing/upload.go +++ b/mediaapi/routing/upload.go @@ -22,6 +22,7 @@ import ( "io" "net/http" "net/url" + "os" "path" "strings" @@ -311,6 +312,26 @@ func (r *uploadRequest) storeFileAndMetadata( } go func() { + file, err := os.Open(string(finalPath)) + if err != nil { + r.Logger.WithError(err).Error("unable to open file") + return + } + defer 
file.Close() // nolint: errcheck + // http.DetectContentType only needs 512 bytes + buf := make([]byte, 512) + _, err = file.Read(buf) + if err != nil { + r.Logger.WithError(err).Error("unable to read file") + return + } + // Check if we need to generate thumbnails + fileType := http.DetectContentType(buf) + if !strings.HasPrefix(fileType, "image") { + r.Logger.WithField("contentType", fileType).Debugf("uploaded file is not an image or can not be thumbnailed, not generating thumbnails") + return + } + busy, err := thumbnailer.GenerateThumbnails( context.Background(), finalPath, thumbnailSizes, r.MediaMetadata, activeThumbnailGeneration, maxThumbnailGenerators, db, r.Logger, diff --git a/mediaapi/routing/upload_test.go b/mediaapi/routing/upload_test.go index e81254f35..b2c2f5a44 100644 --- a/mediaapi/routing/upload_test.go +++ b/mediaapi/routing/upload_test.go @@ -51,7 +51,7 @@ func Test_uploadRequest_doUpload(t *testing.T) { _ = os.Mkdir(testdataPath, os.ModePerm) defer fileutils.RemoveDir(types.Path(testdataPath), nil) - db, err := storage.Open(&config.DatabaseOptions{ + db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{ ConnectionString: "file::memory:?cache=shared", MaxOpenConnections: 100, MaxIdleConnections: 2, diff --git a/mediaapi/storage/interface.go b/mediaapi/storage/interface.go index 843199719..d083be1eb 100644 --- a/mediaapi/storage/interface.go +++ b/mediaapi/storage/interface.go @@ -22,9 +22,17 @@ import ( ) type Database interface { + MediaRepository + Thumbnails +} + +type MediaRepository interface { StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) +} + +type Thumbnails interface { StoreThumbnail(ctx context.Context, 
thumbnailMetadata *types.ThumbnailMetadata) error GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) diff --git a/mediaapi/storage/postgres/media_repository_table.go b/mediaapi/storage/postgres/media_repository_table.go index 1d3264ca9..41cee4878 100644 --- a/mediaapi/storage/postgres/media_repository_table.go +++ b/mediaapi/storage/postgres/media_repository_table.go @@ -20,6 +20,8 @@ import ( "database/sql" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -69,24 +71,25 @@ type mediaStatements struct { selectMediaByHashStmt *sql.Stmt } -func (s *mediaStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(mediaSchema) +func NewPostgresMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) { + s := &mediaStatements{} + _, err := db.Exec(mediaSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, sqlutil.StatementList{ {&s.insertMediaStmt, insertMediaSQL}, {&s.selectMediaStmt, selectMediaSQL}, {&s.selectMediaByHashStmt, selectMediaByHashSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *mediaStatements) insertMedia( - ctx context.Context, mediaMetadata *types.MediaMetadata, +func (s *mediaStatements) InsertMedia( + ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata, ) error { - mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) - _, err := s.insertMediaStmt.ExecContext( + mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + _, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext( ctx, mediaMetadata.MediaID, 
mediaMetadata.Origin, @@ -100,14 +103,14 @@ func (s *mediaStatements) insertMedia( return err } -func (s *mediaStatements) selectMedia( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, +func (s *mediaStatements) SelectMedia( + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ MediaID: mediaID, Origin: mediaOrigin, } - err := s.selectMediaStmt.QueryRowContext( + err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext( ctx, mediaMetadata.MediaID, mediaMetadata.Origin, ).Scan( &mediaMetadata.ContentType, @@ -120,14 +123,14 @@ func (s *mediaStatements) selectMedia( return &mediaMetadata, err } -func (s *mediaStatements) selectMediaByHash( - ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, +func (s *mediaStatements) SelectMediaByHash( + ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ Base64Hash: mediaHash, Origin: mediaOrigin, } - err := s.selectMediaStmt.QueryRowContext( + err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext( ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin, ).Scan( &mediaMetadata.ContentType, diff --git a/mediaapi/storage/postgres/mediaapi.go b/mediaapi/storage/postgres/mediaapi.go new file mode 100644 index 000000000..ea70e575b --- /dev/null +++ b/mediaapi/storage/postgres/mediaapi.go @@ -0,0 +1,46 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + // Import the postgres database driver. + _ "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/storage/shared" + "github.com/matrix-org/dendrite/setup/config" +) + +// NewDatabase opens a postgres database. +func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, err + } + mediaRepo, err := NewPostgresMediaRepositoryTable(db) + if err != nil { + return nil, err + } + thumbnails, err := NewPostgresThumbnailsTable(db) + if err != nil { + return nil, err + } + return &shared.Database{ + MediaRepository: mediaRepo, + Thumbnails: thumbnails, + DB: db, + Writer: sqlutil.NewExclusiveWriter(), + }, nil +} diff --git a/mediaapi/storage/postgres/prepare.go b/mediaapi/storage/postgres/prepare.go deleted file mode 100644 index a2e01884e..000000000 --- a/mediaapi/storage/postgres/prepare.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -// FIXME: This should be made internal! - -package postgres - -import ( - "database/sql" -) - -// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. -type statementList []struct { - statement **sql.Stmt - sql string -} - -// prepare the SQL for each statement in the list and assign the result to the prepared statement. -func (s statementList) prepare(db *sql.DB) (err error) { - for _, statement := range s { - if *statement.statement, err = db.Prepare(statement.sql); err != nil { - return - } - } - return -} diff --git a/mediaapi/storage/postgres/sql.go b/mediaapi/storage/postgres/sql.go deleted file mode 100644 index 181cd15ff..000000000 --- a/mediaapi/storage/postgres/sql.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package postgres - -import ( - "database/sql" -) - -type statements struct { - media mediaStatements - thumbnail thumbnailStatements -} - -func (s *statements) prepare(db *sql.DB) (err error) { - if err = s.media.prepare(db); err != nil { - return - } - if err = s.thumbnail.prepare(db); err != nil { - return - } - - return -} diff --git a/mediaapi/storage/postgres/thumbnail_table.go b/mediaapi/storage/postgres/thumbnail_table.go index 3f28cdbbf..7e07b476e 100644 --- a/mediaapi/storage/postgres/thumbnail_table.go +++ b/mediaapi/storage/postgres/thumbnail_table.go @@ -21,6 +21,8 @@ import ( "time" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -63,7 +65,7 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE // Note: this selects all thumbnails for a media_origin and media_id const selectThumbnailsSQL = ` -SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 +SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC ` type thumbnailStatements struct { @@ -72,24 +74,25 @@ type thumbnailStatements struct { selectThumbnailsStmt *sql.Stmt } -func (s *thumbnailStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(thumbnailSchema) +func NewPostgresThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) { + s := &thumbnailStatements{} + _, err := db.Exec(thumbnailSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, sqlutil.StatementList{ {&s.insertThumbnailStmt, insertThumbnailSQL}, {&s.selectThumbnailStmt, selectThumbnailSQL}, {&s.selectThumbnailsStmt, selectThumbnailsSQL}, - }.prepare(db) + 
}.Prepare(db) } -func (s *thumbnailStatements) insertThumbnail( - ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, +func (s *thumbnailStatements) InsertThumbnail( + ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata, ) error { - thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) - _, err := s.insertThumbnailStmt.ExecContext( + thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + _, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext( ctx, thumbnailMetadata.MediaMetadata.MediaID, thumbnailMetadata.MediaMetadata.Origin, @@ -103,8 +106,9 @@ func (s *thumbnailStatements) insertThumbnail( return err } -func (s *thumbnailStatements) selectThumbnail( +func (s *thumbnailStatements) SelectThumbnail( ctx context.Context, + txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, @@ -121,7 +125,7 @@ func (s *thumbnailStatements) selectThumbnail( ResizeMethod: resizeMethod, }, } - err := s.selectThumbnailStmt.QueryRowContext( + err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext( ctx, thumbnailMetadata.MediaMetadata.MediaID, thumbnailMetadata.MediaMetadata.Origin, @@ -136,10 +140,10 @@ func (s *thumbnailStatements) selectThumbnail( return &thumbnailMetadata, err } -func (s *thumbnailStatements) selectThumbnails( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, +func (s *thumbnailStatements) SelectThumbnails( + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, ) ([]*types.ThumbnailMetadata, error) { - rows, err := s.selectThumbnailsStmt.QueryContext( + rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext( ctx, mediaID, mediaOrigin, ) if err != nil { diff --git a/mediaapi/storage/postgres/storage.go b/mediaapi/storage/shared/mediaapi.go similarity 
index 52% rename from mediaapi/storage/postgres/storage.go rename to mediaapi/storage/shared/mediaapi.go index 61ad468fe..c8d9ad6ab 100644 --- a/mediaapi/storage/postgres/storage.go +++ b/mediaapi/storage/shared/mediaapi.go @@ -1,5 +1,4 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,54 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. -package postgres +package shared import ( "context" "database/sql" - // Import the postgres database driver. - _ "github.com/lib/pq" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) -// Database is used to store metadata about a repository of media files. type Database struct { - statements statements - db *sql.DB -} - -// Open opens a postgres database. -func Open(dbProperties *config.DatabaseOptions) (*Database, error) { - var d Database - var err error - if d.db, err = sqlutil.Open(dbProperties); err != nil { - return nil, err - } - if err = d.statements.prepare(d.db); err != nil { - return nil, err - } - return &d, nil + DB *sql.DB + Writer sqlutil.Writer + MediaRepository tables.MediaRepository + Thumbnails tables.Thumbnails } // StoreMediaMetadata inserts the metadata about the uploaded media into the database. // Returns an error if the combination of MediaID and Origin are not unique in the table. 
-func (d *Database) StoreMediaMetadata( - ctx context.Context, mediaMetadata *types.MediaMetadata, -) error { - return d.statements.media.insertMedia(ctx, mediaMetadata) +func (d Database) StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.MediaRepository.InsertMedia(ctx, txn, mediaMetadata) + }) } // GetMediaMetadata returns metadata about media stored on this server. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there is no metadata associated with this media. -func (d *Database) GetMediaMetadata( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, -) (*types.MediaMetadata, error) { - mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin) +func (d Database) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) { + mediaMetadata, err := d.MediaRepository.SelectMedia(ctx, nil, mediaID, mediaOrigin) if err != nil && err == sql.ErrNoRows { return nil, nil } @@ -70,10 +53,8 @@ func (d *Database) GetMediaMetadata( // GetMediaMetadataByHash returns metadata about media stored on this server. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there is no metadata associated with this media. 
-func (d *Database) GetMediaMetadataByHash( - ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, -) (*types.MediaMetadata, error) { - mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin) +func (d Database) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) { + mediaMetadata, err := d.MediaRepository.SelectMediaByHash(ctx, nil, mediaHash, mediaOrigin) if err != nil && err == sql.ErrNoRows { return nil, nil } @@ -82,40 +63,36 @@ func (d *Database) GetMediaMetadataByHash( // StoreThumbnail inserts the metadata about the thumbnail into the database. // Returns an error if the combination of MediaID and Origin are not unique in the table. -func (d *Database) StoreThumbnail( - ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, -) error { - return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata) +func (d Database) StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Thumbnails.InsertThumbnail(ctx, txn, thumbnailMetadata) + }) } // GetThumbnail returns metadata about a specific thumbnail. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there is no metadata associated with this thumbnail. 
-func (d *Database) GetThumbnail( - ctx context.Context, - mediaID types.MediaID, - mediaOrigin gomatrixserverlib.ServerName, - width, height int, - resizeMethod string, -) (*types.ThumbnailMetadata, error) { - thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail( - ctx, mediaID, mediaOrigin, width, height, resizeMethod, - ) - if err != nil && err == sql.ErrNoRows { - return nil, nil +func (d Database) GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) { + metadata, err := d.Thumbnails.SelectThumbnail(ctx, nil, mediaID, mediaOrigin, width, height, resizeMethod) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err } - return thumbnailMetadata, err + return metadata, err } // GetThumbnails returns metadata about all thumbnails for a specific media stored on this server. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there are no thumbnails associated with this media. 
-func (d *Database) GetThumbnails( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, -) ([]*types.ThumbnailMetadata, error) { - thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin) - if err != nil && err == sql.ErrNoRows { - return nil, nil +func (d Database) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) { + metadatas, err := d.Thumbnails.SelectThumbnails(ctx, nil, mediaID, mediaOrigin) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err } - return thumbnails, err + return metadatas, err } diff --git a/mediaapi/storage/sqlite3/media_repository_table.go b/mediaapi/storage/sqlite3/media_repository_table.go index bcef609d8..78431967f 100644 --- a/mediaapi/storage/sqlite3/media_repository_table.go +++ b/mediaapi/storage/sqlite3/media_repository_table.go @@ -21,6 +21,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -66,57 +67,53 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, media_id, user_i type mediaStatements struct { db *sql.DB - writer sqlutil.Writer insertMediaStmt *sql.Stmt selectMediaStmt *sql.Stmt selectMediaByHashStmt *sql.Stmt } -func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { - s.db = db - s.writer = writer - - _, err = db.Exec(mediaSchema) +func NewSQLiteMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) { + s := &mediaStatements{ + db: db, + } + _, err := db.Exec(mediaSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, sqlutil.StatementList{ {&s.insertMediaStmt, insertMediaSQL}, {&s.selectMediaStmt, selectMediaSQL}, {&s.selectMediaByHashStmt, selectMediaByHashSQL}, - }.prepare(db) + 
}.Prepare(db) } -func (s *mediaStatements) insertMedia( - ctx context.Context, mediaMetadata *types.MediaMetadata, +func (s *mediaStatements) InsertMedia( + ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata, ) error { - mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertMediaStmt) - _, err := stmt.ExecContext( - ctx, - mediaMetadata.MediaID, - mediaMetadata.Origin, - mediaMetadata.ContentType, - mediaMetadata.FileSizeBytes, - mediaMetadata.CreationTimestamp, - mediaMetadata.UploadName, - mediaMetadata.Base64Hash, - mediaMetadata.UserID, - ) - return err - }) + mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + _, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext( + ctx, + mediaMetadata.MediaID, + mediaMetadata.Origin, + mediaMetadata.ContentType, + mediaMetadata.FileSizeBytes, + mediaMetadata.CreationTimestamp, + mediaMetadata.UploadName, + mediaMetadata.Base64Hash, + mediaMetadata.UserID, + ) + return err } -func (s *mediaStatements) selectMedia( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, +func (s *mediaStatements) SelectMedia( + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ MediaID: mediaID, Origin: mediaOrigin, } - err := s.selectMediaStmt.QueryRowContext( + err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext( ctx, mediaMetadata.MediaID, mediaMetadata.Origin, ).Scan( &mediaMetadata.ContentType, @@ -129,14 +126,14 @@ func (s *mediaStatements) selectMedia( return &mediaMetadata, err } -func (s *mediaStatements) selectMediaByHash( - ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, +func (s *mediaStatements) SelectMediaByHash( + ctx context.Context, 
txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ Base64Hash: mediaHash, Origin: mediaOrigin, } - err := s.selectMediaStmt.QueryRowContext( + err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext( ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin, ).Scan( &mediaMetadata.ContentType, diff --git a/mediaapi/storage/sqlite3/sql.go b/mediaapi/storage/sqlite3/mediaapi.go similarity index 51% rename from mediaapi/storage/sqlite3/sql.go rename to mediaapi/storage/sqlite3/mediaapi.go index 245bd40cc..abf329367 100644 --- a/mediaapi/storage/sqlite3/sql.go +++ b/mediaapi/storage/sqlite3/mediaapi.go @@ -16,23 +16,30 @@ package sqlite3 import ( - "database/sql" - + // Import the SQL utilities and the shared database implementation. "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/storage/shared" + "github.com/matrix-org/dendrite/setup/config" ) -type statements struct { - media mediaStatements - thumbnail thumbnailStatements -} - -func (s *statements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { - if err = s.media.prepare(db, writer); err != nil { - return +// NewDatabase opens a SQLite database. 
+func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, err } - if err = s.thumbnail.prepare(db, writer); err != nil { - return + mediaRepo, err := NewSQLiteMediaRepositoryTable(db) + if err != nil { + return nil, err } - - return + thumbnails, err := NewSQLiteThumbnailsTable(db) + if err != nil { + return nil, err + } + return &shared.Database{ + MediaRepository: mediaRepo, + Thumbnails: thumbnails, + DB: db, + Writer: sqlutil.NewExclusiveWriter(), + }, nil } diff --git a/mediaapi/storage/sqlite3/prepare.go b/mediaapi/storage/sqlite3/prepare.go deleted file mode 100644 index 8fb3b56f3..000000000 --- a/mediaapi/storage/sqlite3/prepare.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// FIXME: This should be made internal! - -package sqlite3 - -import ( - "database/sql" -) - -// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. -type statementList []struct { - statement **sql.Stmt - sql string -} - -// prepare the SQL for each statement in the list and assign the result to the prepared statement. 
-func (s statementList) prepare(db *sql.DB) (err error) { - for _, statement := range s { - if *statement.statement, err = db.Prepare(statement.sql); err != nil { - return - } - } - return -} diff --git a/mediaapi/storage/sqlite3/storage.go b/mediaapi/storage/sqlite3/storage.go deleted file mode 100644 index fa442173b..000000000 --- a/mediaapi/storage/sqlite3/storage.go +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - - // Import the postgres database driver. - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" -) - -// Database is used to store metadata about a repository of media files. -type Database struct { - statements statements - db *sql.DB - writer sqlutil.Writer -} - -// Open opens a postgres database. -func Open(dbProperties *config.DatabaseOptions) (*Database, error) { - d := Database{ - writer: sqlutil.NewExclusiveWriter(), - } - var err error - if d.db, err = sqlutil.Open(dbProperties); err != nil { - return nil, err - } - if err = d.statements.prepare(d.db, d.writer); err != nil { - return nil, err - } - return &d, nil -} - -// StoreMediaMetadata inserts the metadata about the uploaded media into the database. 
-// Returns an error if the combination of MediaID and Origin are not unique in the table. -func (d *Database) StoreMediaMetadata( - ctx context.Context, mediaMetadata *types.MediaMetadata, -) error { - return d.statements.media.insertMedia(ctx, mediaMetadata) -} - -// GetMediaMetadata returns metadata about media stored on this server. -// The media could have been uploaded to this server or fetched from another server and cached here. -// Returns nil metadata if there is no metadata associated with this media. -func (d *Database) GetMediaMetadata( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, -) (*types.MediaMetadata, error) { - mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin) - if err != nil && err == sql.ErrNoRows { - return nil, nil - } - return mediaMetadata, err -} - -// GetMediaMetadataByHash returns metadata about media stored on this server. -// The media could have been uploaded to this server or fetched from another server and cached here. -// Returns nil metadata if there is no metadata associated with this media. -func (d *Database) GetMediaMetadataByHash( - ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, -) (*types.MediaMetadata, error) { - mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin) - if err != nil && err == sql.ErrNoRows { - return nil, nil - } - return mediaMetadata, err -} - -// StoreThumbnail inserts the metadata about the thumbnail into the database. -// Returns an error if the combination of MediaID and Origin are not unique in the table. -func (d *Database) StoreThumbnail( - ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, -) error { - return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata) -} - -// GetThumbnail returns metadata about a specific thumbnail. -// The media could have been uploaded to this server or fetched from another server and cached here. 
-// Returns nil metadata if there is no metadata associated with this thumbnail. -func (d *Database) GetThumbnail( - ctx context.Context, - mediaID types.MediaID, - mediaOrigin gomatrixserverlib.ServerName, - width, height int, - resizeMethod string, -) (*types.ThumbnailMetadata, error) { - thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail( - ctx, mediaID, mediaOrigin, width, height, resizeMethod, - ) - if err != nil && err == sql.ErrNoRows { - return nil, nil - } - return thumbnailMetadata, err -} - -// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server. -// The media could have been uploaded to this server or fetched from another server and cached here. -// Returns nil metadata if there are no thumbnails associated with this media. -func (d *Database) GetThumbnails( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, -) ([]*types.ThumbnailMetadata, error) { - thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin) - if err != nil && err == sql.ErrNoRows { - return nil, nil - } - return thumbnails, err -} diff --git a/mediaapi/storage/sqlite3/thumbnail_table.go b/mediaapi/storage/sqlite3/thumbnail_table.go index 06b056b6e..5ff2fece0 100644 --- a/mediaapi/storage/sqlite3/thumbnail_table.go +++ b/mediaapi/storage/sqlite3/thumbnail_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -54,55 +55,48 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE // Note: this selects all thumbnails for a media_origin and media_id const selectThumbnailsSQL = ` -SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 
+SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC ` type thumbnailStatements struct { - db *sql.DB - writer sqlutil.Writer insertThumbnailStmt *sql.Stmt selectThumbnailStmt *sql.Stmt selectThumbnailsStmt *sql.Stmt } -func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { - _, err = db.Exec(thumbnailSchema) +func NewSQLiteThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) { + s := &thumbnailStatements{} + _, err := db.Exec(thumbnailSchema) if err != nil { - return + return nil, err } - s.db = db - s.writer = writer - return statementList{ + return s, sqlutil.StatementList{ {&s.insertThumbnailStmt, insertThumbnailSQL}, {&s.selectThumbnailStmt, selectThumbnailSQL}, {&s.selectThumbnailsStmt, selectThumbnailsSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *thumbnailStatements) insertThumbnail( - ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, -) error { - thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt) - _, err := stmt.ExecContext( - ctx, - thumbnailMetadata.MediaMetadata.MediaID, - thumbnailMetadata.MediaMetadata.Origin, - thumbnailMetadata.MediaMetadata.ContentType, - thumbnailMetadata.MediaMetadata.FileSizeBytes, - thumbnailMetadata.MediaMetadata.CreationTimestamp, - thumbnailMetadata.ThumbnailSize.Width, - thumbnailMetadata.ThumbnailSize.Height, - thumbnailMetadata.ThumbnailSize.ResizeMethod, - ) - return err - }) +func (s *thumbnailStatements) InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error { + thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + _, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext( + ctx, + 
thumbnailMetadata.MediaMetadata.MediaID, + thumbnailMetadata.MediaMetadata.Origin, + thumbnailMetadata.MediaMetadata.ContentType, + thumbnailMetadata.MediaMetadata.FileSizeBytes, + thumbnailMetadata.MediaMetadata.CreationTimestamp, + thumbnailMetadata.ThumbnailSize.Width, + thumbnailMetadata.ThumbnailSize.Height, + thumbnailMetadata.ThumbnailSize.ResizeMethod, + ) + return err } -func (s *thumbnailStatements) selectThumbnail( +func (s *thumbnailStatements) SelectThumbnail( ctx context.Context, + txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, @@ -119,7 +113,7 @@ func (s *thumbnailStatements) selectThumbnail( ResizeMethod: resizeMethod, }, } - err := s.selectThumbnailStmt.QueryRowContext( + err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext( ctx, thumbnailMetadata.MediaMetadata.MediaID, thumbnailMetadata.MediaMetadata.Origin, @@ -134,10 +128,11 @@ func (s *thumbnailStatements) selectThumbnail( return &thumbnailMetadata, err } -func (s *thumbnailStatements) selectThumbnails( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, +func (s *thumbnailStatements) SelectThumbnails( + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, + mediaOrigin gomatrixserverlib.ServerName, ) ([]*types.ThumbnailMetadata, error) { - rows, err := s.selectThumbnailsStmt.QueryContext( + rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext( ctx, mediaID, mediaOrigin, ) if err != nil { diff --git a/mediaapi/storage/storage.go b/mediaapi/storage/storage.go index 56059f1c8..baa242e57 100644 --- a/mediaapi/storage/storage.go +++ b/mediaapi/storage/storage.go @@ -25,13 +25,13 @@ import ( "github.com/matrix-org/dendrite/setup/config" ) -// Open opens a postgres database. -func Open(dbProperties *config.DatabaseOptions) (Database, error) { +// NewMediaAPIDatasource opens a database connection. 
+func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.Open(dbProperties) + return sqlite3.NewDatabase(dbProperties) case dbProperties.ConnectionString.IsPostgres(): - return postgres.Open(dbProperties) + return postgres.NewDatabase(dbProperties) default: return nil, fmt.Errorf("unexpected database type") } } diff --git a/mediaapi/storage/storage_test.go b/mediaapi/storage/storage_test.go new file mode 100644 index 000000000..8d3403045 --- /dev/null +++ b/mediaapi/storage/storage_test.go @@ -0,0 +1,135 @@ +package storage_test + +import ( + "context" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/mediaapi/storage" + "github.com/matrix-org/dendrite/mediaapi/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" +) + +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("NewMediaAPIDatasource returned %s", err) + } + return db, close +} +func TestMediaRepository(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + ctx := context.Background() + t.Run("can insert media & query media", func(t *testing.T) { + metadata := &types.MediaMetadata{ + MediaID: "testing", + Origin: "localhost", + ContentType: "image/png", + FileSizeBytes: 10, + UploadName: "upload test", + Base64Hash: "dGVzdGluZw==", + UserID: "@alice:localhost", + } + if err := db.StoreMediaMetadata(ctx, metadata); err != nil { + t.Fatalf("unable to store media metadata: %v", err) + } + // query by media id + gotMetadata, err := db.GetMediaMetadata(ctx, metadata.MediaID, metadata.Origin) + if err != nil { + t.Fatalf("unable 
to query media metadata: %v", err) + } + if !reflect.DeepEqual(metadata, gotMetadata) { + t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata) + } + // query by media hash + gotMetadata, err = db.GetMediaMetadataByHash(ctx, metadata.Base64Hash, metadata.Origin) + if err != nil { + t.Fatalf("unable to query media metadata by hash: %v", err) + } + if !reflect.DeepEqual(metadata, gotMetadata) { + t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata) + } + }) + }) +} + +func TestThumbnailsStorage(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + ctx := context.Background() + t.Run("can insert thumbnails & query media", func(t *testing.T) { + thumbnails := []*types.ThumbnailMetadata{ + { + MediaMetadata: &types.MediaMetadata{ + MediaID: "testing", + Origin: "localhost", + ContentType: "image/png", + FileSizeBytes: 6, + }, + ThumbnailSize: types.ThumbnailSize{ + Width: 5, + Height: 5, + ResizeMethod: types.Crop, + }, + }, + { + MediaMetadata: &types.MediaMetadata{ + MediaID: "testing", + Origin: "localhost", + ContentType: "image/png", + FileSizeBytes: 7, + }, + ThumbnailSize: types.ThumbnailSize{ + Width: 1, + Height: 1, + ResizeMethod: types.Scale, + }, + }, + } + for i := range thumbnails { + if err := db.StoreThumbnail(ctx, thumbnails[i]); err != nil { + t.Fatalf("unable to store thumbnail metadata: %v", err) + } + } + // query by single thumbnail + gotMetadata, err := db.GetThumbnail(ctx, + thumbnails[0].MediaMetadata.MediaID, + thumbnails[0].MediaMetadata.Origin, + thumbnails[0].ThumbnailSize.Width, thumbnails[0].ThumbnailSize.Height, + thumbnails[0].ThumbnailSize.ResizeMethod, + ) + if err != nil { + t.Fatalf("unable to query thumbnail metadata: %v", err) + } + if !reflect.DeepEqual(thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata) { + t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata) + } + 
 if !reflect.DeepEqual(thumbnails[0].ThumbnailSize, gotMetadata.ThumbnailSize) { + t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].ThumbnailSize, gotMetadata.ThumbnailSize) + } + // query by all thumbnails + gotMediadatas, err := db.GetThumbnails(ctx, thumbnails[0].MediaMetadata.MediaID, thumbnails[0].MediaMetadata.Origin) + if err != nil { + t.Fatalf("unable to query thumbnail metadatas: %v", err) + } + if len(gotMediadatas) != len(thumbnails) { + t.Fatalf("expected %d stored thumbnail metadata, got %d", len(thumbnails), len(gotMediadatas)) + } + for i := range gotMediadatas { + if !reflect.DeepEqual(thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata) { + t.Fatalf("expected metadata %+v, got %v", thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata) + } + if !reflect.DeepEqual(thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize) { + t.Fatalf("expected metadata %+v, got %v", thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize) + } + } + }) + }) +} diff --git a/mediaapi/storage/storage_wasm.go b/mediaapi/storage/storage_wasm.go index a6e997b2a..f67f9d5e1 100644 --- a/mediaapi/storage/storage_wasm.go +++ b/mediaapi/storage/storage_wasm.go @@ -22,10 +22,10 @@ import ( ) // Open opens a postgres database. -func Open(dbProperties *config.DatabaseOptions) (Database, error) { +func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.Open(dbProperties) + return sqlite3.NewDatabase(dbProperties) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/mediaapi/storage/tables/interface.go b/mediaapi/storage/tables/interface.go new file mode 100644 index 000000000..bf63bc6ab --- /dev/null +++ b/mediaapi/storage/tables/interface.go @@ -0,0 +1,46 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/mediaapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type Thumbnails interface { + InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error + SelectThumbnail( + ctx context.Context, txn *sql.Tx, + mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, + width, height int, + resizeMethod string, + ) (*types.ThumbnailMetadata, error) + SelectThumbnails( + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, + mediaOrigin gomatrixserverlib.ServerName, + ) ([]*types.ThumbnailMetadata, error) +} + +type MediaRepository interface { + InsertMedia(ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata) error + SelectMedia(ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) + SelectMediaByHash( + ctx context.Context, txn *sql.Tx, + mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, + ) (*types.MediaMetadata, error) +} diff --git a/mediaapi/types/types.go b/mediaapi/types/types.go index 0ba7010ad..ab28b3410 100644 --- a/mediaapi/types/types.go +++ b/mediaapi/types/types.go @@ -45,16 +45,13 @@ type RequestMethod string // MatrixUserID is a Matrix user ID string in the form @user:domain e.g. 
@alice:matrix.org type MatrixUserID string -// UnixMs is the milliseconds since the Unix epoch -type UnixMs int64 - // MediaMetadata is metadata associated with a media file type MediaMetadata struct { MediaID MediaID Origin gomatrixserverlib.ServerName ContentType ContentType FileSizeBytes FileSizeBytes - CreationTimestamp UnixMs + CreationTimestamp gomatrixserverlib.Timestamp UploadName Filename Base64Hash Base64Hash UserID MatrixUserID diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index 443744b6f..82834239b 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -36,7 +36,7 @@ import ( type Notifier struct { lock *sync.RWMutex // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine - roomIDToJoinedUsers map[string]userIDSet + roomIDToJoinedUsers map[string]*userIDSet // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine roomIDToPeekingDevices map[string]peekingDeviceSet // The latest sync position @@ -54,7 +54,7 @@ type Notifier struct { // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). 
func NewNotifier() *Notifier { return &Notifier{ - roomIDToJoinedUsers: make(map[string]userIDSet), + roomIDToJoinedUsers: make(map[string]*userIDSet), roomIDToPeekingDevices: make(map[string]peekingDeviceSet), userDeviceStreams: make(map[string]map[string]*UserDeviceStream), lock: &sync.RWMutex{}, @@ -262,7 +262,7 @@ func (n *Notifier) SharedUsers(userID string) []string { func (n *Notifier) _sharedUsers(userID string) []string { n._sharedUserMap[userID] = struct{}{} for roomID, users := range n.roomIDToJoinedUsers { - if _, ok := users[userID]; !ok { + if ok := users.isIn(userID); !ok { continue } for _, userID := range n._joinedUsers(roomID) { @@ -282,8 +282,11 @@ func (n *Notifier) IsSharedUser(userA, userB string) bool { defer n.lock.RUnlock() var okA, okB bool for _, users := range n.roomIDToJoinedUsers { - _, okA = users[userA] - _, okB = users[userB] + okA = users.isIn(userA) + if !okA { + continue + } + okB = users.isIn(userB) if okA && okB { return true } @@ -345,11 +348,12 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) { // This is just the bulk form of addJoinedUser for roomID, userIDs := range roomIDToUserIDs { if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { - n.roomIDToJoinedUsers[roomID] = make(userIDSet, len(userIDs)) + n.roomIDToJoinedUsers[roomID] = newUserIDSet(len(userIDs)) } for _, userID := range userIDs { n.roomIDToJoinedUsers[roomID].add(userID) } + n.roomIDToJoinedUsers[roomID].precompute() } } @@ -440,16 +444,18 @@ func (n *Notifier) _fetchUserStreams(userID string) []*UserDeviceStream { func (n *Notifier) _addJoinedUser(roomID, userID string) { if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { - n.roomIDToJoinedUsers[roomID] = make(userIDSet) + n.roomIDToJoinedUsers[roomID] = newUserIDSet(8) } n.roomIDToJoinedUsers[roomID].add(userID) + n.roomIDToJoinedUsers[roomID].precompute() } func (n *Notifier) _removeJoinedUser(roomID, userID string) { if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { - 
n.roomIDToJoinedUsers[roomID] = make(userIDSet) + n.roomIDToJoinedUsers[roomID] = newUserIDSet(8) } n.roomIDToJoinedUsers[roomID].remove(userID) + n.roomIDToJoinedUsers[roomID].precompute() } func (n *Notifier) JoinedUsers(roomID string) (userIDs []string) { @@ -521,19 +527,52 @@ func (n *Notifier) _removeEmptyUserStreams() { } // A string set, mainly existing for improving clarity of structs in this file. -type userIDSet map[string]struct{} - -func (s userIDSet) add(str string) { - s[str] = struct{}{} +type userIDSet struct { + sync.Mutex + set map[string]struct{} + precomputed []string } -func (s userIDSet) remove(str string) { - delete(s, str) +func newUserIDSet(cap int) *userIDSet { + return &userIDSet{ + set: make(map[string]struct{}, cap), + precomputed: nil, + } } -func (s userIDSet) values() (vals []string) { - vals = make([]string, 0, len(s)) - for str := range s { +func (s *userIDSet) add(str string) { + s.Lock() + defer s.Unlock() + s.set[str] = struct{}{} + s.precomputed = s.precomputed[:0] // invalidate cache +} + +func (s *userIDSet) remove(str string) { + s.Lock() + defer s.Unlock() + delete(s.set, str) + s.precomputed = s.precomputed[:0] // invalidate cache +} + +func (s *userIDSet) precompute() { + s.Lock() + defer s.Unlock() + s.precomputed = s.values() +} + +func (s *userIDSet) isIn(str string) bool { + s.Lock() + defer s.Unlock() + _, ok := s.set[str] + return ok +} + +func (s *userIDSet) values() (vals []string) { + if len(s.precomputed) > 0 { + return s.precomputed // only return if not invalidated + } + vals = make([]string, 0, len(s.set)) + for str := range s.set { vals = append(vals, str) } return diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 2412bc2ae..aaa0c61bf 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -60,7 +60,9 @@ func Context( Headers: nil, } } - filter.Rooms = append(filter.Rooms, roomID) + if filter.Rooms != nil { + *filter.Rooms = append(*filter.Rooms, roomID) + } ctx 
:= req.Context() membershipRes := roomserver.QueryMembershipForUserResponse{} diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 36ba3a3e6..519aeff68 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -262,12 +262,8 @@ func (r *messagesReq) retrieveEvents() ( clientEvents []gomatrixserverlib.ClientEvent, start, end types.TopologyToken, err error, ) { - eventFilter := r.filter - // Retrieve the events from the local database. - streamEvents, err := r.db.GetEventsInTopologicalRange( - r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering, - ) + streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering) if err != nil { err = fmt.Errorf("GetEventsInRange: %w", err) return diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 841f67261..14cb08a52 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -104,8 +104,8 @@ type Database interface { // DeletePeek deletes all peeks for a given room by a given user // Returns an error if there was a problem communicating with the database. DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) - // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. - GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) + // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last. 
+ GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error) // EventPositionInTopology returns the depth and stream position of the given event. EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. diff --git a/syncapi/storage/postgres/backwards_extremities_table.go b/syncapi/storage/postgres/backwards_extremities_table.go index d5cf563a6..d4515735c 100644 --- a/syncapi/storage/postgres/backwards_extremities_table.go +++ b/syncapi/storage/postgres/backwards_extremities_table.go @@ -47,14 +47,10 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" -const deleteBackwardExtremitiesForRoomSQL = "" + - "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" - type backwardExtremitiesStatements struct { insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt - deleteBackwardExtremitiesForRoomStmt *sql.Stmt } func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -72,9 +68,6 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { return nil, err } - if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil { - return nil, err - } return s, nil } @@ -113,10 +106,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return } - -func (s *backwardExtremitiesStatements) 
DeleteBackwardExtremitiesForRoom( - ctx context.Context, txn *sql.Tx, roomID string, -) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID) - return err -} diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 69e6e30ec..fe68788d1 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -233,9 +233,10 @@ func (s *currentRoomStateStatements) SelectCurrentState( excludeEventIDs []string, ) ([]*gomatrixserverlib.HeaderedEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) + senders, notSenders := getSendersStateFilterFilter(stateFilter) rows, err := stmt.QueryContext(ctx, roomID, - pq.StringArray(stateFilter.Senders), - pq.StringArray(stateFilter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), stateFilter.ContainsURL, diff --git a/syncapi/storage/postgres/filtering.go b/syncapi/storage/postgres/filtering.go index dcc421362..a2ca42156 100644 --- a/syncapi/storage/postgres/filtering.go +++ b/syncapi/storage/postgres/filtering.go @@ -16,21 +16,45 @@ package postgres import ( "strings" + + "github.com/matrix-org/gomatrixserverlib" ) // filterConvertWildcardToSQL converts wildcards as defined in // https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter // to SQL wildcards that can be used with LIKE() -func filterConvertTypeWildcardToSQL(values []string) []string { +func filterConvertTypeWildcardToSQL(values *[]string) []string { if values == nil { // Return nil instead of []string{} so IS NULL can work correctly when // the return value is passed into SQL queries return nil } - ret := make([]string, len(values)) - for i := range values { - ret[i] = 
strings.Replace(values[i], "*", "%", -1) + v := *values + ret := make([]string, len(v)) + for i := range v { + ret[i] = strings.Replace(v[i], "*", "%", -1) } return ret } + +// TODO: Replace when Dendrite uses Go 1.18 +func getSendersRoomEventFilter(filter *gomatrixserverlib.RoomEventFilter) (senders []string, notSenders []string) { + if filter.Senders != nil { + senders = *filter.Senders + } + if filter.NotSenders != nil { + notSenders = *filter.NotSenders + } + return senders, notSenders +} + +func getSendersStateFilterFilter(filter *gomatrixserverlib.StateFilter) (senders []string, notSenders []string) { + if filter.Senders != nil { + senders = *filter.Senders + } + if filter.NotSenders != nil { + notSenders = *filter.NotSenders + } + return senders, notSenders +} diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 1242a3221..39fa656cb 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -56,12 +56,6 @@ const upsertMembershipSQL = "" + " ON CONFLICT ON CONSTRAINT syncapi_memberships_unique" + " DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6" -const selectMembershipSQL = "" + - "SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" + - " WHERE room_id = $1 AND user_id = $2 AND membership = ANY($3)" + - " ORDER BY stream_pos DESC" + - " LIMIT 1" - const selectMembershipCountSQL = "" + "SELECT COUNT(*) FROM (" + " SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" + @@ -69,7 +63,6 @@ const selectMembershipCountSQL = "" + type membershipsStatements struct { upsertMembershipStmt *sql.Stmt - selectMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt } @@ -82,9 +75,6 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { if s.upsertMembershipStmt, err = 
db.Prepare(upsertMembershipSQL); err != nil { return nil, err } - if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil { - return nil, err - } if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil { return nil, err } @@ -111,14 +101,6 @@ func (s *membershipsStatements) UpsertMembership( return err } -func (s *membershipsStatements) SelectMembership( - ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string, -) (eventID string, streamPos, topologyPos types.StreamPosition, err error) { - stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt) - err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos) - return -} - func (s *membershipsStatements) SelectMembershipCount( ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition, ) (count int, err error) { diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 14af6a949..17e2feab6 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -81,6 +81,15 @@ const insertEventSQL = "" + const selectEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" +const selectEventsWithFilterSQL = "" + + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" + + " AND ( $2::text[] IS NULL OR sender = ANY($2) )" + + " AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" + + " AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" + + " AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" + + " AND ( $6::bool IS NULL OR contains_url = $6 )" + + " LIMIT $7" + const selectRecentEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, 
transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + @@ -153,6 +162,7 @@ const selectContextAfterEventSQL = "" + type outputRoomEventsStatements struct { insertEventStmt *sql.Stmt selectEventsStmt *sql.Stmt + selectEventsWitFilterStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt selectRecentEventsForSyncStmt *sql.Stmt @@ -174,6 +184,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { return s, sqlutil.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventsStmt, selectEventsSQL}, + {&s.selectEventsWitFilterStmt, selectEventsWithFilterSQL}, {&s.selectMaxEventIDStmt, selectMaxEventIDSQL}, {&s.selectRecentEventsStmt, selectRecentEventsSQL}, {&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL}, @@ -204,11 +215,11 @@ func (s *outputRoomEventsStatements) SelectStateInRange( stateFilter *gomatrixserverlib.StateFilter, roomIDs []string, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) - + senders, notSenders := getSendersStateFilterFilter(stateFilter) rows, err := stmt.QueryContext( ctx, r.Low(), r.High(), pq.StringArray(roomIDs), - pq.StringArray(stateFilter.Senders), - pq.StringArray(stateFilter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), stateFilter.ContainsURL, @@ -310,7 +321,7 @@ func (s *outputRoomEventsStatements) InsertEvent( // Parse content as JSON and search for an "url" key containsURL := false var content map[string]interface{} - if json.Unmarshal(event.Content(), &content) != nil { + if json.Unmarshal(event.Content(), &content) == nil { // Set containsURL to true if url is present _, containsURL = content["url"] } @@ -353,10 +364,11 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( } else { stmt = 
sqlutil.TxStmt(txn, s.selectRecentEventsStmt) } + senders, notSenders := getSendersRoomEventFilter(eventFilter) rows, err := stmt.QueryContext( ctx, roomID, r.Low(), r.High(), - pq.StringArray(eventFilter.Senders), - pq.StringArray(eventFilter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)), eventFilter.Limit+1, @@ -398,11 +410,12 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, ) ([]types.StreamEvent, error) { + senders, notSenders := getSendersRoomEventFilter(eventFilter) stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt) rows, err := stmt.QueryContext( ctx, roomID, r.Low(), r.High(), - pq.StringArray(eventFilter.Senders), - pq.StringArray(eventFilter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)), eventFilter.Limit, @@ -427,15 +440,52 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. 
func (s *outputRoomEventsStatements) SelectEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool, ) ([]types.StreamEvent, error) { - stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) - rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) + var ( + stmt *sql.Stmt + rows *sql.Rows + err error + ) + if filter == nil { + stmt = sqlutil.TxStmt(txn, s.selectEventsStmt) + rows, err = stmt.QueryContext(ctx, pq.StringArray(eventIDs)) + } else { + senders, notSenders := getSendersRoomEventFilter(filter) + stmt = sqlutil.TxStmt(txn, s.selectEventsWitFilterStmt) + rows, err = stmt.QueryContext(ctx, + pq.StringArray(eventIDs), + pq.StringArray(senders), + pq.StringArray(notSenders), + pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)), + pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)), + filter.ContainsURL, + filter.Limit, + ) + } if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") - return rowsToStreamEvents(rows) + streamEvents, err := rowsToStreamEvents(rows) + if err != nil { + return nil, err + } + if preserveOrder { + eventMap := make(map[string]types.StreamEvent) + for _, ev := range streamEvents { + eventMap[ev.EventID()] = ev + } + var returnEvents []types.StreamEvent + for _, eventID := range eventIDs { + ev, ok := eventMap[eventID] + if ok { + returnEvents = append(returnEvents, ev) + } + } + return returnEvents, nil + } + return streamEvents, nil } func (s *outputRoomEventsStatements) DeleteEventsForRoom( @@ -462,10 +512,11 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn func (s *outputRoomEventsStatements) SelectContextBeforeEvent( ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter, ) (evts []*gomatrixserverlib.HeaderedEvent, err error) { + senders, notSenders := 
getSendersRoomEventFilter(filter) rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext( ctx, roomID, id, filter.Limit, - pq.StringArray(filter.Senders), - pq.StringArray(filter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)), ) @@ -494,10 +545,11 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent( func (s *outputRoomEventsStatements) SelectContextAfterEvent( ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter, ) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) { + senders, notSenders := getSendersRoomEventFilter(filter) rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext( ctx, roomID, id, filter.Limit, - pq.StringArray(filter.Senders), - pq.StringArray(filter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)), ) diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index 626386ba0..a1fc9b2a3 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -73,9 +73,6 @@ const selectMaxPositionInTopologySQL = "" + "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" + ") ORDER BY stream_position DESC LIMIT 1" -const deleteTopologyForRoomSQL = "" + - "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" - const selectStreamToTopologicalPositionAscSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;" @@ -88,7 +85,6 @@ type 
outputRoomEventsTopologyStatements struct { selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt - deleteTopologyForRoomStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt } @@ -114,9 +110,6 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { return nil, err } - if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil { - return nil, err - } if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { return nil, err } @@ -148,9 +141,9 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( // is requested or not. var stmt *sql.Stmt if chronologicalOrder { - stmt = s.selectEventIDsInRangeASCStmt + stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt) } else { - stmt = s.selectEventIDsInRangeDESCStmt + stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt) } // Query the event IDs. @@ -203,10 +196,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) return } - -func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom( - ctx context.Context, txn *sql.Tx, roomID string, -) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID) - return err -} diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 1c45d5d9a..2143fd672 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre // Returns an error if there was a problem talking with the database. 
// Does not include any transaction IDs in the returned events. func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { - streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs) + streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false) if err != nil { return nil, err } @@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e // Check if we have all of the event's previous events. If an event is // missing, add it to the room's backward extremities. - prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs()) + prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false) if err != nil { return err } @@ -429,7 +429,8 @@ func (d *Database) updateRoomState( func (d *Database) GetEventsInTopologicalRange( ctx context.Context, from, to *types.TopologyToken, - roomID string, limit int, + roomID string, + filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool, ) (events []types.StreamEvent, err error) { var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition @@ -450,14 +451,14 @@ func (d *Database) GetEventsInTopologicalRange( // Select the event IDs from the defined range. var eIDs []string eIDs, err = d.Topology.SelectEventIDsInRange( - ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, limit, !backwardOrdering, + ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering, ) if err != nil { return } // Retrieve the events' contents using their IDs. - events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs) + events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, filter, true) return } @@ -619,7 +620,7 @@ func (d *Database) fetchMissingStateEvents( ) ([]types.StreamEvent, error) { // Fetch from the events table first so we pick up the stream ID for the // event. 
- events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs) + events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false) if err != nil { return nil, err } @@ -687,6 +688,9 @@ func (d *Database) GetStateDeltas( // user has ever interacted with — joined to, kicked/banned from, left. memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } return nil, nil, err } @@ -704,17 +708,23 @@ func (d *Database) GetStateDeltas( // get all the state events ever (i.e. for all available rooms) between these two positions stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } return nil, nil, err } state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } return nil, nil, err } // find out which rooms this user is peeking, if any. // We do this before joins so any peeks get overwritten peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) - if err != nil { + if err != nil && err != sql.ErrNoRows { return nil, nil, err } @@ -725,6 +735,9 @@ func (d *Database) GetStateDeltas( var s []types.StreamEvent s, err = d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter) if err != nil { + if err == sql.ErrNoRows { + continue + } return nil, nil, err } state[peek.RoomID] = s @@ -752,6 +765,9 @@ func (d *Database) GetStateDeltas( var s []types.StreamEvent s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter) if err != nil { + if err == sql.ErrNoRows { + continue + } return nil, nil, err } state[roomID] = s @@ -802,6 +818,9 @@ func (d *Database) GetStateDeltasForFullStateSync( // user has ever interacted with — joined to, kicked/banned from, left. 
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } return nil, nil, err } @@ -818,7 +837,7 @@ func (d *Database) GetStateDeltasForFullStateSync( deltas := make(map[string]types.StateDelta) peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) - if err != nil { + if err != nil && err != sql.ErrNoRows { return nil, nil, err } @@ -827,6 +846,9 @@ func (d *Database) GetStateDeltasForFullStateSync( if !peek.Deleted { s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter) if stateErr != nil { + if stateErr == sql.ErrNoRows { + continue + } return nil, nil, stateErr } deltas[peek.RoomID] = types.StateDelta{ @@ -840,10 +862,16 @@ func (d *Database) GetStateDeltasForFullStateSync( // Get all the state events ever between these two positions stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } return nil, nil, err } state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } return nil, nil, err } @@ -868,6 +896,9 @@ func (d *Database) GetStateDeltasForFullStateSync( for _, joinedRoomID := range joinedRoomIDs { s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter) if stateErr != nil { + if stateErr == sql.ErrNoRows { + continue + } return nil, nil, stateErr } deltas[joinedRoomID] = types.StateDelta{ diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 24c442240..71a098177 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -41,23 +41,23 @@ const insertAccountDataSQL = "" + " ON CONFLICT (user_id, room_id, type) DO UPDATE" + " SET id = $5" +// further parameters are added by 
prepareWithFilters const selectAccountDataInRangeSQL = "" + "SELECT room_id, type FROM syncapi_account_data_type" + - " WHERE user_id = $1 AND id > $2 AND id <= $3" + - " ORDER BY id ASC" + " WHERE user_id = $1 AND id > $2 AND id <= $3" const selectMaxAccountDataIDSQL = "" + "SELECT MAX(id) FROM syncapi_account_data_type" type accountDataStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements insertAccountDataStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt selectAccountDataInRangeStmt *sql.Stmt } -func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) { +func NewSqliteAccountDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.AccountData, error) { s := &accountDataStatements{ db: db, streamIDStatements: streamID, @@ -94,18 +94,24 @@ func (s *accountDataStatements) SelectAccountDataInRange( ctx context.Context, userID string, r types.Range, - accountDataFilterPart *gomatrixserverlib.EventFilter, + filter *gomatrixserverlib.EventFilter, ) (data map[string][]string, err error) { data = make(map[string][]string) + stmt, params, err := prepareWithFilters( + s.db, nil, selectAccountDataInRangeSQL, + []interface{}{ + userID, r.Low(), r.High(), + }, + filter.Senders, filter.NotSenders, + filter.Types, filter.NotTypes, + []string{}, nil, filter.Limit, FilterOrderAsc) - rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High()) + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return } defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") - var entries int - for rows.Next() { var dataType string var roomID string @@ -114,31 +120,11 @@ func (s *accountDataStatements) SelectAccountDataInRange( return } - // check if we should add this by looking at the filter. 
- // It would be nice if we could do this in SQL-land, but the mix of variadic - // and positional parameters makes the query annoyingly hard to do, it's easier - // and clearer to do it in Go-land. If there are no filters for [not]types then - // this gets skipped. - for _, includeType := range accountDataFilterPart.Types { - if includeType != dataType { // TODO: wildcard support - continue - } - } - for _, excludeType := range accountDataFilterPart.NotTypes { - if excludeType == dataType { // TODO: wildcard support - continue - } - } - if len(data[roomID]) > 0 { data[roomID] = append(data[roomID], dataType) } else { data[roomID] = []string{dataType} } - entries++ - if entries >= accountDataFilterPart.Limit { - break - } } return data, nil diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index 662cb0252..c5674dded 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -47,15 +47,11 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" -const deleteBackwardExtremitiesForRoomSQL = "" + - "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" - type backwardExtremitiesStatements struct { db *sql.DB insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt - deleteBackwardExtremitiesForRoomStmt *sql.Stmt } func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -75,9 +71,6 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { return nil, err } - if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil { - return nil, 
err - } return s, nil } @@ -116,10 +109,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return err } - -func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom( - ctx context.Context, txn *sql.Tx, roomID string, -) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID) - return err -} diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 473aa49b0..ccda005c1 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -90,7 +90,7 @@ const selectEventsWithEventIDsSQL = "" + type currentRoomStateStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateForRoomStmt *sql.Stmt @@ -100,7 +100,7 @@ type currentRoomStateStatements struct { selectStateEventStmt *sql.Stmt } -func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { +func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { s := ¤tRoomStateStatements{ db: db, streamIDStatements: streamID, @@ -220,7 +220,7 @@ func (s *currentRoomStateStatements) SelectCurrentState( }, stateFilter.Senders, stateFilter.NotSenders, stateFilter.Types, stateFilter.NotTypes, - excludeEventIDs, stateFilter.Limit, FilterOrderNone, + excludeEventIDs, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderNone, ) if err != nil { return nil, fmt.Errorf("s.prepareWithFilters: %w", err) diff --git a/syncapi/storage/sqlite3/filtering.go b/syncapi/storage/sqlite3/filtering.go index 11f3e647b..05edb7b8c 100644 --- a/syncapi/storage/sqlite3/filtering.go +++ 
b/syncapi/storage/sqlite3/filtering.go @@ -25,34 +25,53 @@ const ( // parts. func prepareWithFilters( db *sql.DB, txn *sql.Tx, query string, params []interface{}, - senders, notsenders, types, nottypes []string, excludeEventIDs []string, - limit int, order FilterOrder, + senders, notsenders, types, nottypes *[]string, excludeEventIDs []string, + containsURL *bool, limit int, order FilterOrder, ) (*sql.Stmt, []interface{}, error) { offset := len(params) - if count := len(senders); count > 0 { - query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range senders { - params, offset = append(params, v), offset+1 + if senders != nil { + if count := len(*senders); count > 0 { + query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range *senders { + params, offset = append(params, v), offset+1 + } + } else { + query += ` AND sender = ""` } } - if count := len(notsenders); count > 0 { - query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range notsenders { - params, offset = append(params, v), offset+1 + if notsenders != nil { + if count := len(*notsenders); count > 0 { + query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range *notsenders { + params, offset = append(params, v), offset+1 + } + } else { + query += ` AND sender NOT = ""` } } - if count := len(types); count > 0 { - query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range types { - params, offset = append(params, v), offset+1 + if types != nil { + if count := len(*types); count > 0 { + query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range *types { + params, offset = append(params, v), offset+1 + } + } else { + query += ` AND type = ""` } } - if count := len(nottypes); count > 0 { - query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range nottypes { - params, offset = 
append(params, v), offset+1 + if nottypes != nil { + if count := len(*nottypes); count > 0 { + query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range *nottypes { + params, offset = append(params, v), offset+1 + } + } else { + query += ` AND type NOT = ""` } } + if containsURL != nil { + query += fmt.Sprintf(" AND contains_url = %v", *containsURL) + } if count := len(excludeEventIDs); count > 0 { query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset) for _, v := range excludeEventIDs { diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 0a6823cc0..58ab8461e 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -59,14 +59,14 @@ const selectMaxInviteIDSQL = "" + type inviteEventsStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements insertInviteEventStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt deleteInviteEventStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt } -func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { +func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) { s := &inviteEventsStatements{ db: db, streamIDStatements: streamID, diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index 776bf3da3..9f3530ccd 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -18,7 +18,6 @@ import ( "context" "database/sql" "fmt" - "strings" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" @@ -57,12 +56,6 @@ const upsertMembershipSQL = "" + " ON CONFLICT (room_id, user_id, membership)" + " DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6" -const selectMembershipSQL = "" + - "SELECT event_id, 
stream_pos, topological_pos FROM syncapi_memberships" + - " WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" + - " ORDER BY stream_pos DESC" + - " LIMIT 1" - const selectMembershipCountSQL = "" + "SELECT COUNT(*) FROM (" + " SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" + @@ -111,22 +104,6 @@ func (s *membershipsStatements) UpsertMembership( return err } -func (s *membershipsStatements) SelectMembership( - ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string, -) (eventID string, streamPos, topologyPos types.StreamPosition, err error) { - params := []interface{}{roomID, userID} - for _, membership := range memberships { - params = append(params, membership) - } - orig := strings.Replace(selectMembershipSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1) - stmt, err := s.db.Prepare(orig) - if err != nil { - return "", 0, 0, err - } - err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos) - return -} - func (s *membershipsStatements) SelectMembershipCount( ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition, ) (count int, err error) { diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index acd959696..188f7582b 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -58,7 +58,7 @@ const insertEventSQL = "" + "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)" const selectEventsSQL = "" + - "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)" const 
selectRecentEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + @@ -111,9 +111,8 @@ const selectContextAfterEventSQL = "" + type outputRoomEventsStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements insertEventStmt *sql.Stmt - selectEventsStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt updateEventJSONStmt *sql.Stmt deleteEventsForRoomStmt *sql.Stmt @@ -122,7 +121,7 @@ type outputRoomEventsStatements struct { selectContextAfterEventStmt *sql.Stmt } -func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { +func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) { s := &outputRoomEventsStatements{ db: db, streamIDStatements: streamID, @@ -133,7 +132,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even } return s, sqlutil.StatementList{ {&s.insertEventStmt, insertEventSQL}, - {&s.selectEventsStmt, selectEventsSQL}, {&s.selectMaxEventIDStmt, selectMaxEventIDSQL}, {&s.updateEventJSONStmt, updateEventJSONSQL}, {&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL}, @@ -170,7 +168,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( s.db, txn, stmtSQL, inputParams, stateFilter.Senders, stateFilter.NotSenders, stateFilter.Types, stateFilter.NotTypes, - nil, stateFilter.Limit, FilterOrderAsc, + nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc, ) if err != nil { return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) @@ -279,7 +277,7 @@ func (s *outputRoomEventsStatements) InsertEvent( // Parse content as JSON and search for an "url" key containsURL := false var content map[string]interface{} - if json.Unmarshal(event.Content(), &content) != nil { + if json.Unmarshal(event.Content(), &content) == nil { // Set containsURL to true if url is present _, containsURL = content["url"] } @@ -347,7 +345,7 @@ 
func (s *outputRoomEventsStatements) SelectRecentEvents( }, eventFilter.Senders, eventFilter.NotSenders, eventFilter.Types, eventFilter.NotTypes, - nil, eventFilter.Limit+1, FilterOrderDesc, + nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc, ) if err != nil { return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) @@ -395,7 +393,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( }, eventFilter.Senders, eventFilter.NotSenders, eventFilter.Types, eventFilter.NotTypes, - nil, eventFilter.Limit, FilterOrderAsc, + nil, eventFilter.ContainsURL, eventFilter.Limit, FilterOrderAsc, ) if err != nil { return nil, fmt.Errorf("s.prepareWithFilters: %w", err) @@ -421,21 +419,50 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. func (s *outputRoomEventsStatements) SelectEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool, ) ([]types.StreamEvent, error) { - var returnEvents []types.StreamEvent - stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) - for _, eventID := range eventIDs { - rows, err := stmt.QueryContext(ctx, eventID) - if err != nil { - return nil, err - } - if streamEvents, err := rowsToStreamEvents(rows); err == nil { - returnEvents = append(returnEvents, streamEvents...) 
- } - internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") + iEventIDs := make([]interface{}, len(eventIDs)) + for i := range eventIDs { + iEventIDs[i] = eventIDs[i] } - return returnEvents, nil + selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1) + + if filter == nil { + filter = &gomatrixserverlib.RoomEventFilter{Limit: 20} + } + stmt, params, err := prepareWithFilters( + s.db, txn, selectSQL, iEventIDs, + filter.Senders, filter.NotSenders, + filter.Types, filter.NotTypes, + nil, filter.ContainsURL, filter.Limit, FilterOrderAsc, + ) + if err != nil { + return nil, err + } + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") + streamEvents, err := rowsToStreamEvents(rows) + if err != nil { + return nil, err + } + if preserveOrder { + var returnEvents []types.StreamEvent + eventMap := make(map[string]types.StreamEvent) + for _, ev := range streamEvents { + eventMap[ev.EventID()] = ev + } + for _, eventID := range eventIDs { + ev, ok := eventMap[eventID] + if ok { + returnEvents = append(returnEvents, ev) + } + } + return returnEvents, nil + } + return streamEvents, nil } func (s *outputRoomEventsStatements) DeleteEventsForRoom( @@ -507,7 +534,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent( }, filter.Senders, filter.NotSenders, filter.Types, filter.NotTypes, - nil, filter.Limit, FilterOrderDesc, + nil, filter.ContainsURL, filter.Limit, FilterOrderDesc, ) rows, err := stmt.QueryContext(ctx, params...) @@ -543,7 +570,7 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent( }, filter.Senders, filter.NotSenders, filter.Types, filter.NotTypes, - nil, filter.Limit, FilterOrderAsc, + nil, filter.ContainsURL, filter.Limit, FilterOrderAsc, ) rows, err := stmt.QueryContext(ctx, params...) 
diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index b972ae285..b2fb77417 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -78,7 +78,6 @@ type outputRoomEventsTopologyStatements struct { selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt - deleteTopologyForRoomStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt } @@ -191,10 +190,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) return } - -func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom( - ctx context.Context, txn *sql.Tx, roomID string, -) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID) - return err -} diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go index c93c82051..5ee86448c 100644 --- a/syncapi/storage/sqlite3/peeks_table.go +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -66,7 +66,7 @@ const selectMaxPeekIDSQL = "" + type peekStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements insertPeekStmt *sql.Stmt deletePeekStmt *sql.Stmt deletePeeksStmt *sql.Stmt @@ -75,7 +75,7 @@ type peekStatements struct { selectMaxPeekIDStmt *sql.Stmt } -func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) { +func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) { _, err := db.Exec(peeksSchema) if err != nil { return nil, err diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go index e7b78a705..00b16458d 100644 --- 
a/syncapi/storage/sqlite3/presence_table.go +++ b/syncapi/storage/sqlite3/presence_table.go @@ -75,7 +75,7 @@ const selectPresenceAfter = "" + type presenceStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements upsertPresenceStmt *sql.Stmt upsertPresenceFromSyncStmt *sql.Stmt selectPresenceForUsersStmt *sql.Stmt @@ -83,7 +83,7 @@ type presenceStatements struct { selectPresenceAfterStmt *sql.Stmt } -func NewSqlitePresenceTable(db *sql.DB, streamID *streamIDStatements) (*presenceStatements, error) { +func NewSqlitePresenceTable(db *sql.DB, streamID *StreamIDStatements) (*presenceStatements, error) { _, err := db.Exec(presenceSchema) if err != nil { return nil, err diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index dea057719..bd778bf3c 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -59,13 +59,13 @@ const selectMaxReceiptIDSQL = "" + type receiptStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements upsertReceipt *sql.Stmt selectRoomReceipts *sql.Stmt selectMaxReceiptID *sql.Stmt } -func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) { +func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) { _, err := db.Exec(receiptsSchema) if err != nil { return nil, err diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index faa2c41fe..71980b806 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -32,12 +32,12 @@ const increaseStreamIDStmt = "" + "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" + " RETURNING stream_id" -type streamIDStatements struct { +type StreamIDStatements struct { db *sql.DB increaseStreamIDStmt *sql.Stmt } -func (s 
*streamIDStatements) prepare(db *sql.DB) (err error) { +func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) { s.db = db _, err = db.Exec(streamIDTableSchema) if err != nil { @@ -49,31 +49,31 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) { return } -func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { +func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos) return } -func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { +func (s *StreamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos) return } -func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { +func (s *StreamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos) return } -func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { +func (s *StreamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos) return } -func (s *streamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { +func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := 
sqlutil.TxStmt(txn, s.increaseStreamIDStmt) err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos) return diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 9d9d35988..dfc289482 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -30,7 +30,7 @@ type SyncServerDatasource struct { shared.Database db *sql.DB writer sqlutil.Writer - streamID streamIDStatements + streamID StreamIDStatements } // NewDatabase creates a new sync server database @@ -49,7 +49,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e } func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) { - if err = d.streamID.prepare(d.db); err != nil { + if err = d.streamID.Prepare(d.db); err != nil { return err } accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID) diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 403b50eaa..15bb769a2 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -3,6 +3,7 @@ package storage_test import ( "context" "fmt" + "reflect" "testing" "github.com/matrix-org/dendrite/setup/config" @@ -38,7 +39,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver if err != nil { t.Fatalf("WriteEvent failed: %s", err) } - fmt.Println("Event ID", ev.EventID(), " spos=", pos, "depth=", ev.Depth()) + t.Logf("Event ID %s spos=%v depth=%v", ev.EventID(), pos, ev.Depth()) positions = append(positions, pos) } return @@ -46,7 +47,6 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver func TestWriteEvents(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - t.Parallel() alice := test.NewUser() r := test.NewRoom(t, alice) db, close := MustCreateDatabase(t, dbType) @@ -61,84 +61,84 @@ func TestRecentEventsPDU(t *testing.T) { db, close := MustCreateDatabase(t, dbType) 
defer close() alice := test.NewUser() - var filter gomatrixserverlib.RoomEventFilter - filter.Limit = 100 + // dummy room to make sure SQL queries are filtering on room ID + MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) + + // actual test room r := test.NewRoom(t, alice) r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"}) events := r.Events() positions := MustWriteEvents(t, db, events) + + // dummy room to make sure SQL queries are filtering on room ID + MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) + latest, err := db.MaxStreamPositionForPDUs(ctx) if err != nil { t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err) } testCases := []struct { - Name string - From types.StreamPosition - To types.StreamPosition - WantEvents []*gomatrixserverlib.HeaderedEvent - WantLimited bool + Name string + From types.StreamPosition + To types.StreamPosition + Limit int + ReverseOrder bool + WantEvents []*gomatrixserverlib.HeaderedEvent + WantLimited bool }{ // The purpose of this test is to make sure that incremental syncs are including up to the latest events. - // It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event. + // It's a basic sanity test that sync works. It creates a streaming position that is on the penultimate event. // It makes sure the response includes the final event. { - Name: "IncrementalSync penultimate", + Name: "penultimate", From: positions[len(positions)-2], // pretend we are at the penultimate event To: latest, + Limit: 100, WantEvents: events[len(events)-1:], WantLimited: false, }, - /* - // The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the - // number of returned events. This is critical for big rooms hence the test here. 
- { - Name: "IncrementalSync limited", - DoSync: func() (*types.Response, error) { - from := types.StreamingToken{ // pretend we are 10 events behind - PDUPosition: positions[len(positions)-11], - } - res := types.NewResponse() - // limit is set to 5 - return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) - }, - // want the last 5 events, NOT the last 10. - WantTimeline: events[len(events)-5:], - }, - // The purpose of this test is to check that CompleteSync returns all the current state as well as - // honouring the `numRecentEventsPerRoom` value - { - Name: "CompleteSync limited", - DoSync: func() (*types.Response, error) { - res := types.NewResponse() - // limit set to 5 - return db.CompleteSync(ctx, res, testUserDeviceA, 5) - }, - // want the last 5 events - WantTimeline: events[len(events)-5:], - // want all state for the room - WantState: state, - }, - // The purpose of this test is to check that CompleteSync can return everything with a high enough - // `numRecentEventsPerRoom`. - { - Name: "CompleteSync", - DoSync: func() (*types.Response, error) { - res := types.NewResponse() - return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1) - }, - WantTimeline: events, - // We want no state at all as that field in /sync is the delta between the token (beginning of time) - // and the START of the timeline. - }, */ + // The purpose of this test is to check that limits can be applied and work. + // This is critical for big rooms hence the test here. 
+ { + Name: "limited", + From: 0, + To: latest, + Limit: 1, + WantEvents: events[len(events)-1:], + WantLimited: true, + }, + // The purpose of this test is to check that we can return every event with a high + // enough limit + { + Name: "large limited", + From: 0, + To: latest, + Limit: 100, + WantEvents: events, + WantLimited: false, + }, + // The purpose of this test is to check that we can return events in reverse order + { + Name: "reverse", + From: positions[len(positions)-3], // 2 events back + To: latest, + Limit: 100, + ReverseOrder: true, + WantEvents: test.Reversed(events[len(events)-2:]), + WantLimited: false, + }, } - for _, tc := range testCases { + for i := range testCases { + tc := testCases[i] t.Run(tc.Name, func(st *testing.T) { + var filter gomatrixserverlib.RoomEventFilter + filter.Limit = tc.Limit gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{ From: tc.From, To: tc.To, - }, &filter, true, true) + }, &filter, !tc.ReverseOrder, true) if err != nil { st.Fatalf("failed to do sync: %s", err) } @@ -148,100 +148,49 @@ func TestRecentEventsPDU(t *testing.T) { if len(gotEvents) != len(tc.WantEvents) { st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents)) } + for j := range gotEvents { + if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) { + st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON())) + } + } }) } }) } -/* -func TestGetEventsInRangeWithPrevBatch(t *testing.T) { - t.Parallel() - db := MustCreateDatabase(t) - events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) - positions := MustWriteEvents(t, db, events) - latest, err := db.SyncPosition(ctx) - if err != nil { - t.Fatalf("failed to get SyncPosition: %s", err) - } - from := types.StreamingToken{ - PDUPosition: positions[len(positions)-2], - } - - res := types.NewResponse() - res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) - if err != nil { - 
t.Fatalf("failed to IncrementalSync with latest token") - } - roomRes, ok := res.Rooms.Join[testRoomID] - if !ok { - t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res) - } - // returns the last event "Message 10" - assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:])) - - prev := roomRes.Timeline.PrevBatch.String() - if prev == "" { - t.Fatalf("IncrementalSync expected prev_batch token") - } - prevBatchToken, err := types.NewTopologyTokenFromString(prev) - if err != nil { - t.Fatalf("failed to NewTopologyTokenFromString : %s", err) - } - // backpaginate 5 messages starting at the latest position. - // head towards the beginning of time - to := types.TopologyToken{} - paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true) - if err != nil { - t.Fatalf("GetEventsInRange returned an error: %s", err) - } - gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) - assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1])) -} - -// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token. -func TestGetEventsInRangeWithStreamToken(t *testing.T) { - t.Parallel() - db := MustCreateDatabase(t) - events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) - MustWriteEvents(t, db, events) - latest, err := db.SyncPosition(ctx) - if err != nil { - t.Fatalf("failed to get SyncPosition: %s", err) - } - // head towards the beginning of time - to := types.StreamingToken{} - - // backpaginate 5 messages starting at the latest position. 
- paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true) - if err != nil { - t.Fatalf("GetEventsInRange returned an error: %s", err) - } - gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) - assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:])) -} - // The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token func TestGetEventsInRangeWithTopologyToken(t *testing.T) { - t.Parallel() - db := MustCreateDatabase(t) - events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) - MustWriteEvents(t, db, events) - from, err := db.MaxTopologicalPosition(ctx, testRoomID) - if err != nil { - t.Fatalf("failed to get MaxTopologicalPosition: %s", err) - } - // head towards the beginning of time - to := types.TopologyToken{} + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := MustCreateDatabase(t, dbType) + defer close() + alice := test.NewUser() + r := test.NewRoom(t, alice) + for i := 0; i < 10; i++ { + r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)}) + } + events := r.Events() + _ = MustWriteEvents(t, db, events) - // backpaginate 5 messages starting at the latest position. 
- paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true) - if err != nil { - t.Fatalf("GetEventsInRange returned an error: %s", err) - } - gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) - assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:])) + from, err := db.MaxTopologicalPosition(ctx, r.ID) + if err != nil { + t.Fatalf("failed to get MaxTopologicalPosition: %s", err) + } + t.Logf("max topo pos = %+v", from) + // head towards the beginning of time + to := types.TopologyToken{} + + // backpaginate 5 messages starting at the latest position. + filter := &gomatrixserverlib.RoomEventFilter{Limit: 5} + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true) + if err != nil { + t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) + } + gots := db.StreamEventsToEvents(nil, paginatedEvents) + test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:])) + }) } +/* // The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth. // For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent // will appear FIRST when going backwards. 
This test creates a DAG like: @@ -651,12 +600,4 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ tok.Decrement() return &tok } - -func reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { - out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) - for i := 0; i < len(in); i++ { - out[i] = in[len(in)-i-1] - } - return out -} */ diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 8d368eec1..993e2022b 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -59,7 +59,7 @@ type Events interface { SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) // SelectEarlyEvents returns the earliest events in the given room. SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) - SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) + SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) @@ -84,8 +84,6 @@ type Topology interface { SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) // SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position. 
SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error) - // DeleteTopologyForRoom removes all topological information for a room. This should only be done when removing the room entirely. - DeleteTopologyForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) // SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room. SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error) } @@ -132,8 +130,6 @@ type BackwardsExtremities interface { SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error) // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) - // DeleteBackwardExtremitiesFoorRoomID removes all backward extremities for a room. This should only be done when removing the room entirely. 
- DeleteBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) } // SendToDevice tracks send-to-device messages which are sent to individual @@ -173,7 +169,6 @@ type Receipts interface { type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error - SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) } diff --git a/syncapi/storage/tables/output_room_events_test.go b/syncapi/storage/tables/output_room_events_test.go new file mode 100644 index 000000000..a143e5ecd --- /dev/null +++ b/syncapi/storage/tables/output_room_events_test.go @@ -0,0 +1,105 @@ +package tables_test + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" +) + +func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + var tab tables.Events + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresEventsTable(db) + case test.DBTypeSQLite: + var stream sqlite3.StreamIDStatements + if err = stream.Prepare(db); err != nil { + t.Fatalf("failed to 
prepare stream stmts: %s", err) + } + tab, err = sqlite3.NewSqliteEventsTable(db, &stream) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + return tab, db, close +} + +func TestOutputRoomEventsTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser() + room := test.NewRoom(t, alice) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, db, close := newOutputRoomEventsTable(t, dbType) + defer close() + events := room.Events() + err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error { + for _, ev := range events { + _, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false) + if err != nil { + return fmt.Errorf("failed to InsertEvent: %s", err) + } + } + // order = 2,0,3,1 + wantEventIDs := []string{ + events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(), + } + gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, nil, true) + if err != nil { + return fmt.Errorf("failed to SelectEvents: %s", err) + } + gotEventIDs := make([]string, len(gotEvents)) + for i := range gotEvents { + gotEventIDs[i] = gotEvents[i].EventID() + } + if !reflect.DeepEqual(gotEventIDs, wantEventIDs) { + return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs) + } + + // Test that contains_url is correctly populated + urlEv := room.CreateEvent(t, alice, "m.text", map[string]interface{}{ + "body": "test.txt", + "url": "mxc://test.txt", + }) + if _, err = tab.InsertEvent(ctx, txn, urlEv, nil, nil, nil, false); err != nil { + return fmt.Errorf("failed to InsertEvent: %s", err) + } + wantEventID := []string{urlEv.EventID()} + t := true + gotEvents, err = tab.SelectEvents(ctx, txn, wantEventID, &gomatrixserverlib.RoomEventFilter{Limit: 1, ContainsURL: &t}, true) + if err != nil { + return fmt.Errorf("failed to SelectEvents: %s", err) + } + gotEventIDs = make([]string, len(gotEvents)) + for i := range gotEvents { + gotEventIDs[i] = gotEvents[i].EventID() + } + if 
!reflect.DeepEqual(gotEventIDs, wantEventID) { + return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventID) + } + + return nil + }) + if err != nil { + t.Fatalf("err: %s", err) + } + }) +} diff --git a/syncapi/storage/tables/topology_test.go b/syncapi/storage/tables/topology_test.go new file mode 100644 index 000000000..b6ece0b0d --- /dev/null +++ b/syncapi/storage/tables/topology_test.go @@ -0,0 +1,91 @@ +package tables_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" +) + +func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + var tab tables.Topology + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresTopologyTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSqliteTopologyTable(db) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + return tab, db, close +} + +func TestTopologyTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser() + room := test.NewRoom(t, alice) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, db, close := newTopologyTable(t, dbType) + defer close() + events := room.Events() + err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error { + var highestPos types.StreamPosition + for i, ev := range events { + topoPos, err := tab.InsertEventInTopology(ctx, txn, ev, 
types.StreamPosition(i)) + if err != nil { + return fmt.Errorf("failed to InsertEventInTopology: %s", err) + } + // topo pos = depth, depth starts at 1, hence 1+i + if topoPos != types.StreamPosition(1+i) { + return fmt.Errorf("got topo pos %d want %d", topoPos, 1+i) + } + highestPos = topoPos + 1 + } + // check ordering works without limit + eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true) + if err != nil { + return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) + } + test.AssertEventIDsEqual(t, eventIDs, events[:]) + eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false) + if err != nil { + return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) + } + test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:])) + // check ordering works with limit + eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true) + if err != nil { + return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) + } + test.AssertEventIDsEqual(t, eventIDs, events[:3]) + eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false) + if err != nil { + return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) + } + test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:])) + + return nil + }) + if err != nil { + t.Fatalf("err: %s", err) + } + }) +} diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index ab200e007..df5fb8e08 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -3,9 +3,11 @@ package streams import ( "context" "database/sql" + "fmt" "sync" "time" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -26,7 +28,8 @@ type PDUStreamProvider struct { tasks chan func() workers 
atomic.Int32 - userAPI userapi.UserInternalAPI + // userID+deviceID -> lazy loading cache + lazyLoadCache *caching.LazyLoadCache } func (p *PDUStreamProvider) worker() { @@ -188,7 +191,7 @@ func (p *PDUStreamProvider) IncrementalSync( newPos = from for _, delta := range stateDeltas { var pos types.StreamPosition - if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, req.Response); err != nil { + if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, &stateFilter, req.Response); err != nil { req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed") return to } @@ -203,12 +206,14 @@ func (p *PDUStreamProvider) IncrementalSync( return newPos } +// nolint:gocyclo func (p *PDUStreamProvider) addRoomDeltaToResponse( ctx context.Context, device *userapi.Device, r types.Range, delta types.StateDelta, eventFilter *gomatrixserverlib.RoomEventFilter, + stateFilter *gomatrixserverlib.StateFilter, res *types.Response, ) (types.StreamPosition, error) { if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave { @@ -225,13 +230,16 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( eventFilter, true, true, ) if err != nil { - return r.From, err + if err == sql.ErrNoRows { + return r.To, nil + } + return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err) } recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back prevBatch, err := p.DB.GetBackwardTopologyPos(ctx, recentStreamEvents) if err != nil { - return r.From, err + return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err) } // If we didn't return any events at all then don't bother doing anything else. @@ -247,7 +255,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( // room that were returned. 
latestPosition := r.To updateLatestPosition := func(mostRecentEventID string) { - if _, pos, err := p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil { + var pos types.StreamPosition + if _, pos, err = p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil { switch { case r.Backwards && pos > latestPosition: fallthrough @@ -263,6 +272,16 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) } + if stateFilter.LazyLoadMembers { + delta.StateEvents, err = p.lazyLoadMembers( + ctx, delta.RoomID, true, limited, stateFilter.IncludeRedundantMembers, + device, recentEvents, delta.StateEvents, + ) + if err != nil && err != sql.ErrNoRows { + return r.From, fmt.Errorf("p.lazyLoadMembers: %w", err) + } + } + hasMembershipChange := false for _, recentEvent := range recentStreamEvents { if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil { @@ -322,12 +341,16 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( wantFullState bool, device *userapi.Device, ) (jr *types.JoinResponse, err error) { + jr = types.NewJoinResponse() // TODO: When filters are added, we may need to call this multiple times to get enough events. // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 recentStreamEvents, limited, err := p.DB.RecentEvents( ctx, roomID, r, eventFilter, true, true, ) if err != nil { + if err == sql.ErrNoRows { + return jr, nil + } return } @@ -402,7 +425,20 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( // "Can sync a room with a message with a transaction id" - which does a complete sync to check. 
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) - jr = types.NewJoinResponse() + + if stateFilter.LazyLoadMembers { + if err != nil { + return nil, err + } + stateEvents, err = p.lazyLoadMembers(ctx, roomID, + false, limited, stateFilter.IncludeRedundantMembers, + device, recentEvents, stateEvents, + ) + if err != nil && err != sql.ErrNoRows { + return nil, err + } + } + jr.Summary.JoinedMemberCount = &joinedCount jr.Summary.InvitedMemberCount = &invitedCount jr.Timeline.PrevBatch = prevBatch @@ -412,6 +448,69 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( return jr, nil } +func (p *PDUStreamProvider) lazyLoadMembers( + ctx context.Context, roomID string, + incremental, limited, includeRedundant bool, + device *userapi.Device, + timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent, +) ([]*gomatrixserverlib.HeaderedEvent, error) { + if len(timelineEvents) == 0 { + return stateEvents, nil + } + // Work out which memberships to include + timelineUsers := make(map[string]struct{}) + if !incremental { + timelineUsers[device.UserID] = struct{}{} + } + // Add all users the client doesn't know about yet to a list + for _, event := range timelineEvents { + // Membership is not yet cached, add it to the list + if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.Sender()); !ok { + timelineUsers[event.Sender()] = struct{}{} + } + } + // Preallocate with the same amount, even if it will end up with fewer values + newStateEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(stateEvents)) + // Remove existing membership events we don't care about, e.g. 
users not in the timeline.events + for _, event := range stateEvents { + if event.Type() == gomatrixserverlib.MRoomMember && event.StateKey() != nil { + // If this is a gapped incremental sync, we still want this membership + isGappedIncremental := limited && incremental + // We want this users membership event, keep it in the list + _, ok := timelineUsers[event.Sender()] + wantMembership := ok || isGappedIncremental + if wantMembership { + newStateEvents = append(newStateEvents, event) + if !includeRedundant { + p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, event.Sender(), event.EventID()) + } + delete(timelineUsers, event.Sender()) + } + } else { + newStateEvents = append(newStateEvents, event) + } + } + wantUsers := make([]string, 0, len(timelineUsers)) + for userID := range timelineUsers { + wantUsers = append(wantUsers, userID) + } + // Query missing membership events + memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &gomatrixserverlib.StateFilter{ + Limit: 100, + Senders: &wantUsers, + Types: &[]string{gomatrixserverlib.MRoomMember}, + }) + if err != nil { + return stateEvents, err + } + // cache the membership events + for _, membership := range memberships { + p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, membership.Sender(), membership.EventID()) + } + stateEvents = append(newStateEvents, memberships...) + return stateEvents, nil +} + // addIgnoredUsersToFilter adds ignored users to the eventfilter and // the syncreq itself for further use in streams. 
func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error { @@ -423,8 +522,12 @@ func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *ty return err } req.IgnoredUsers = *ignores + userList := make([]string, 0, len(ignores.List)) for userID := range ignores.List { - eventFilter.NotSenders = append(eventFilter.NotSenders, userID) + userList = append(userList, userID) + } + if len(userList) > 0 { + eventFilter.NotSenders = &userList } return nil } diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index c7d06a296..d3195b78f 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -27,12 +27,12 @@ type Streams struct { func NewSyncStreamProviders( d storage.Database, userAPI userapi.UserInternalAPI, rsAPI rsapi.RoomserverInternalAPI, keyAPI keyapi.KeyInternalAPI, - eduCache *caching.EDUCache, notifier *notifier.Notifier, + eduCache *caching.EDUCache, lazyLoadCache *caching.LazyLoadCache, notifier *notifier.Notifier, ) *Streams { streams := &Streams{ PDUStreamProvider: &PDUStreamProvider{ StreamProvider: StreamProvider{DB: d}, - userAPI: userAPI, + lazyLoadCache: lazyLoadCache, }, TypingStreamProvider: &TypingStreamProvider{ StreamProvider: StreamProvider{DB: d}, diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index 09a62e3dd..f04f172d3 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -15,6 +15,7 @@ package sync import ( + "database/sql" "encoding/json" "fmt" "net/http" @@ -60,10 +61,10 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil { + if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil 
&& err != sql.ErrNoRows { util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed") return nil, fmt.Errorf("syncDB.GetFilter: %w", err) - } else { + } else if f != nil { filter = *f } } diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 384121a8a..2f9165d91 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -57,8 +57,12 @@ func AddPublicRoutes( } eduCache := caching.NewTypingCache() + lazyLoadCache, err := caching.NewLazyLoadCache() + if err != nil { + logrus.WithError(err).Panicf("failed to create lazy loading cache") + } notifier := notifier.NewNotifier() - streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, notifier) + streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, lazyLoadCache, notifier) notifier.SetCurrentPosition(streams.Latest(context.Background())) if err = notifier.Load(context.Background(), syncDB); err != nil { logrus.WithError(err).Panicf("failed to load notifier ") diff --git a/sytest-whitelist b/sytest-whitelist index dc67c9935..979f12bf6 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -312,10 +312,10 @@ Inbound federation can return events Inbound federation can return missing events for world_readable visibility Inbound federation can return missing events for invite visibility Inbound federation can get public room list -POST /rooms/:room_id/redact/:event_id as power user redacts message -POST /rooms/:room_id/redact/:event_id as original message sender redacts message -POST /rooms/:room_id/redact/:event_id as random user does not redact message -POST /redact disallows redaction of event in different room +PUT /rooms/:room_id/redact/:event_id/:txn_id as power user redacts message +PUT /rooms/:room_id/redact/:event_id/:txn_id as original message sender redacts message +PUT /rooms/:room_id/redact/:event_id/:txn_id as random user does not redact message +PUT /redact disallows redaction of event in different room An event which redacts 
itself should be ignored A pair of events which redact each other should be ignored Redaction of a redaction redacts the redaction reason @@ -696,4 +696,17 @@ Room state after a rejected message event is the same as before Room state after a rejected state event is the same as before Ignore user in existing room Ignore invite in full sync -Ignore invite in incremental sync \ No newline at end of file +Ignore invite in incremental sync +A filtered timeline reaches its limit +A change to displayname should not result in a full state sync +Can fetch images in room +The only membership state included in an initial sync is for all the senders in the timeline +The only membership state included in an incremental sync is for senders in the timeline +Old members are included in gappy incr LL sync if they start speaking +We do send redundant membership state across incremental syncs if asked +Rejecting invite over federation doesn't break incremental /sync +Gapped incremental syncs include all state changes +Old leaves are present in gapped incremental syncs +Leaves are present in non-gapped incremental syncs +Members from the gap are included in gappy incr LL sync +Presence can be set from sync \ No newline at end of file diff --git a/test/db.go b/test/db.go index 9deec0a89..6412feaa6 100644 --- a/test/db.go +++ b/test/db.go @@ -15,12 +15,16 @@ package test import ( + "crypto/sha256" "database/sql" + "encoding/hex" "fmt" "os" "os/exec" "os/user" "testing" + + "github.com/lib/pq" ) type DBType int @@ -30,7 +34,7 @@ var DBTypePostgres DBType = 2 var Quiet = false -func createLocalDB(dbName string) string { +func createLocalDB(dbName string) { if !Quiet { fmt.Println("Note: tests require a postgres install accessible to the current user") } @@ -43,7 +47,29 @@ func createLocalDB(dbName string) string { if err != nil && !Quiet { fmt.Println("createLocalDB returned error:", err) } - return dbName +} + +func createRemoteDB(t *testing.T, dbName, user, connStr string) { + db, err 
:= sql.Open("postgres", connStr+" dbname=postgres") + if err != nil { + t.Fatalf("failed to open postgres conn with connstr=%s : %s", connStr, err) + } + _, err = db.Exec(fmt.Sprintf(`CREATE DATABASE %s;`, dbName)) + if err != nil { + pqErr, ok := err.(*pq.Error) + if !ok { + t.Fatalf("failed to CREATE DATABASE: %s", err) + } + // we ignore duplicate database error as we expect this + if pqErr.Code != "42P04" { + t.Fatalf("failed to CREATE DATABASE with code=%s msg=%s", pqErr.Code, pqErr.Message) + } + } + _, err = db.Exec(fmt.Sprintf(`GRANT ALL PRIVILEGES ON DATABASE %s TO %s`, dbName, user)) + if err != nil { + t.Fatalf("failed to GRANT: %s", err) + } + _ = db.Close() } func currentUser() string { @@ -64,6 +90,7 @@ func currentUser() string { // TODO: namespace for concurrent package tests func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) { if dbType == DBTypeSQLite { + // this will be made in the current working directory which namespaces concurrent package runs correctly dbname := "dendrite_test.db" return fmt.Sprintf("file:%s", dbname), func() { err := os.Remove(dbname) @@ -79,13 +106,9 @@ func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, clo if user == "" { user = currentUser() } - dbName := os.Getenv("POSTGRES_DB") - if dbName == "" { - dbName = createLocalDB("dendrite_test") - } connStr = fmt.Sprintf( - "user=%s dbname=%s sslmode=disable", - user, dbName, + "user=%s sslmode=disable", + user, ) // optional vars, used in CI password := os.Getenv("POSTGRES_PASSWORD") @@ -97,6 +120,25 @@ func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, clo connStr += fmt.Sprintf(" host=%s", host) } + // superuser database + postgresDB := os.Getenv("POSTGRES_DB") + // we cannot use 'dendrite_test' here else 2x concurrently running packages will try to use the same db. 
+ // instead, hash the current working directory, snaffle the first 16 bytes and append that to dendrite_test + // and use that as the unique db name. We do this because packages are per-directory hence by hashing the + // working (test) directory we ensure we get a consistent hash and don't hash against concurrent packages. + wd, err := os.Getwd() + if err != nil { + t.Fatalf("cannot get working directory: %s", err) + } + hash := sha256.Sum256([]byte(wd)) + dbName := fmt.Sprintf("dendrite_test_%s", hex.EncodeToString(hash[:16])) + if postgresDB == "" { // local server, use createdb + createLocalDB(dbName) + } else { // remote server, shell into the postgres user and CREATE DATABASE + createRemoteDB(t, dbName, user, connStr) + } + connStr += fmt.Sprintf(" dbname=%s", dbName) + return connStr, func() { // Drop all tables on the database to get a fresh instance db, err := sql.Open("postgres", connStr) @@ -121,6 +163,7 @@ func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) { for dbName, dbType := range dbs { dbt := dbType t.Run(dbName, func(tt *testing.T) { + tt.Parallel() testFn(tt, dbt) }) } diff --git a/test/event.go b/test/event.go index 487b09364..b2e2805ba 100644 --- a/test/event.go +++ b/test/event.go @@ -15,7 +15,9 @@ package test import ( + "bytes" "crypto/ed25519" + "testing" "time" "github.com/matrix-org/gomatrixserverlib" @@ -49,3 +51,40 @@ func WithUnsigned(unsigned interface{}) eventModifier { e.unsigned = unsigned } } + +// Reverse a list of events +func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { + out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) + for i := 0; i < len(in); i++ { + out[i] = in[len(in)-i-1] + } + return out +} + +func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*gomatrixserverlib.HeaderedEvent) { + t.Helper() + if len(gotEventIDs) != len(wants) { + t.Fatalf("length mismatch: got %d events, want %d", len(gotEventIDs), len(wants)) + } + for i := 
range wants { + w := wants[i].EventID() + g := gotEventIDs[i] + if w != g { + t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w)) + } + } +} + +func AssertEventsEqual(t *testing.T, gots, wants []*gomatrixserverlib.HeaderedEvent) { + t.Helper() + if len(gots) != len(wants) { + t.Fatalf("length mismatch: got %d events, want %d", len(gots), len(wants)) + } + for i := range wants { + w := wants[i].JSON() + g := gots[i].JSON() + if !bytes.Equal(w, g) { + t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w)) + } + } +} diff --git a/userapi/api/api.go b/userapi/api/api.go index b86774d14..6aa6a6842 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -492,16 +492,16 @@ type PerformPusherDeletionRequest struct { // Pusher represents a push notification subscriber type Pusher struct { - SessionID int64 `json:"session_id,omitempty"` - PushKey string `json:"pushkey"` - PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"` - Kind PusherKind `json:"kind"` - AppID string `json:"app_id"` - AppDisplayName string `json:"app_display_name"` - DeviceDisplayName string `json:"device_display_name"` - ProfileTag string `json:"profile_tag"` - Language string `json:"lang"` - Data map[string]interface{} `json:"data"` + SessionID int64 `json:"session_id,omitempty"` + PushKey string `json:"pushkey"` + PushKeyTS int64 `json:"pushkey_ts,omitempty"` + Kind PusherKind `json:"kind"` + AppID string `json:"app_id"` + AppDisplayName string `json:"app_display_name"` + DeviceDisplayName string `json:"device_display_name"` + ProfileTag string `json:"profile_tag"` + Language string `json:"lang"` + Data map[string]interface{} `json:"data"` } type PusherKind string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 206c6f7de..d1c12f05f 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -653,7 +653,7 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform return 
a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart) } if req.Pusher.PushKeyTS == 0 { - req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now()) + req.Pusher.PushKeyTS = int64(time.Now().Unix()) } return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart) } diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go index 670dc916f..2eb379ae4 100644 --- a/userapi/storage/postgres/pusher_table.go +++ b/userapi/storage/postgres/pusher_table.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) @@ -95,7 +94,7 @@ type pushersStatements struct { // Returns nil error success. func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, - pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, + pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, ) error { _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) logrus.Debugf("Created pusher %d", session_id) diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go index e718792e1..d5bd1617b 100644 --- a/userapi/storage/sqlite3/pusher_table.go +++ b/userapi/storage/sqlite3/pusher_table.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) @@ -95,7 +94,7 @@ type pushersStatements 
struct { // Returns nil error success. func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, - pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, + pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, ) error { _, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) logrus.Debugf("Created pusher %d", session_id) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 99c907b85..eb0cae314 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) type AccountDataTable interface { @@ -96,7 +95,7 @@ type ThreePIDTable interface { } type PusherTable interface { - InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error + InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error) DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error