mirror of https://github.com/matrix-org/dendrite.git, synced 2026-01-01 03:03:10 -06:00

commit 786b89dcd5
Merge branch 'main' into neilalexander/removelibp2p
@@ -52,6 +52,7 @@ import (
 	"golang.org/x/net/http2"
 	"golang.org/x/net/http2/h2c"

+	pineconeConnections "github.com/matrix-org/pinecone/connections"
 	pineconeMulticast "github.com/matrix-org/pinecone/multicast"
 	pineconeRouter "github.com/matrix-org/pinecone/router"
 	pineconeSessions "github.com/matrix-org/pinecone/sessions"

@@ -71,11 +72,9 @@ type DendriteMonolith struct {
 	PineconeRouter    *pineconeRouter.Router
 	PineconeMulticast *pineconeMulticast.Multicast
 	PineconeQUIC      *pineconeSessions.Sessions
+	PineconeManager   *pineconeConnections.ConnectionManager
 	StorageDirectory  string
 	CacheDirectory    string
-	staticPeerURI     string
-	staticPeerMutex   sync.RWMutex
-	staticPeerAttempt chan struct{}
 	listener          net.Listener
 	httpServer        *http.Server
 	processContext    *process.ProcessContext

@@ -104,15 +103,8 @@ func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) {
 }

 func (m *DendriteMonolith) SetStaticPeer(uri string) {
-	m.staticPeerMutex.Lock()
-	m.staticPeerURI = strings.TrimSpace(uri)
-	m.staticPeerMutex.Unlock()
-	m.DisconnectType(int(pineconeRouter.PeerTypeRemote))
-	if uri != "" {
-		go func() {
-			m.staticPeerAttempt <- struct{}{}
-		}()
-	}
+	m.PineconeManager.RemovePeers()
+	m.PineconeManager.AddPeer(strings.TrimSpace(uri))
 }

 func (m *DendriteMonolith) DisconnectType(peertype int) {
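The hand-rolled reconnection machinery that those struct fields supported — a mutex-guarded URI, a kick channel and a polling goroutine — is replaced by pinecone's ConnectionManager, which owns the desired-peer set and the retry loop itself. A rough sketch of the pattern being delegated, with hypothetical types rather than pinecone's actual internals:

```go
package connmgr

import (
	"context"
	"sync"
	"time"
)

// manager sketches what a connection manager of this kind encapsulates:
// remember which peers we *want*, and periodically re-dial any that are
// not currently connected. dial and online are supplied by the caller.
type manager struct {
	mu      sync.Mutex
	desired map[string]struct{} // peer URI -> wanted
	dial    func(uri string) error
	online  func(uri string) bool
}

func (m *manager) AddPeer(uri string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.desired[uri] = struct{}{}
}

func (m *manager) RemovePeers() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.desired = map[string]struct{}{}
}

// run re-dials disconnected peers every five seconds until ctx is done,
// the same cadence as the loop this commit deletes.
func (m *manager) run(ctx context.Context) {
	t := time.NewTicker(5 * time.Second)
	defer t.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-t.C:
			m.mu.Lock()
			for uri := range m.desired {
				if !m.online(uri) {
					_ = m.dial(uri) // failures are retried on the next tick
				}
			}
			m.mu.Unlock()
		}
	}
}
```

With the desired-peer set owned in one place, SetStaticPeer reduces to RemovePeers followed by AddPeer, and the next hunk removes the old loop wholesale.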
@@ -210,43 +202,6 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e
 	return loginRes.Device.AccessToken, nil
 }

-func (m *DendriteMonolith) staticPeerConnect() {
-	connected := map[string]bool{} // URI -> connected?
-	attempt := func() {
-		m.staticPeerMutex.RLock()
-		uri := m.staticPeerURI
-		m.staticPeerMutex.RUnlock()
-		if uri == "" {
-			return
-		}
-		for k := range connected {
-			delete(connected, k)
-		}
-		for _, uri := range strings.Split(uri, ",") {
-			connected[strings.TrimSpace(uri)] = false
-		}
-		for _, info := range m.PineconeRouter.Peers() {
-			connected[info.URI] = true
-		}
-		for k, online := range connected {
-			if !online {
-				if err := conn.ConnectToPeer(m.PineconeRouter, k); err != nil {
-					logrus.WithError(err).Error("Failed to connect to static peer")
-				}
-			}
-		}
-	}
-	for {
-		select {
-		case <-m.processContext.Context().Done():
-		case <-m.staticPeerAttempt:
-			attempt()
-		case <-time.After(time.Second * 5):
-			attempt()
-		}
-	}
-}
-
 // nolint:gocyclo
 func (m *DendriteMonolith) Start() {
 	var err error

@@ -284,6 +239,7 @@ func (m *DendriteMonolith) Start() {
 	m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
 	m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"})
 	m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter)
+	m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter)

 	prefix := hex.EncodeToString(pk)
 	cfg := &config.Dendrite{}

@@ -392,9 +348,6 @@ func (m *DendriteMonolith) Start() {

 	m.processContext = base.ProcessContext

-	m.staticPeerAttempt = make(chan struct{}, 1)
-	go m.staticPeerConnect()
-
 	go func() {
 		m.logger.Info("Listening on ", cfg.Global.ServerName)
 		m.logger.Fatal(m.httpServer.Serve(m.PineconeQUIC.Protocol("matrix")))
@@ -52,6 +52,7 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client

 	if turnConfig.SharedSecret != "" {
 		expiry := time.Now().Add(duration).Unix()
+		resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID)
 		mac := hmac.New(sha1.New, []byte(turnConfig.SharedSecret))
 		_, err := mac.Write([]byte(resp.Username))

@@ -60,7 +61,6 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client
 			return jsonerror.InternalServerError()
 		}

-		resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID)
 		resp.Password = base64.StdEncoding.EncodeToString(mac.Sum(nil))
 	} else if turnConfig.Username != "" && turnConfig.Password != "" {
 		resp.Username = turnConfig.Username
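The reordering here is a genuine fix, not a cleanup: the HMAC is computed over resp.Username, so the username must be populated before mac.Write runs — in the old order the MAC was taken over an empty string. The handler follows the widely deployed ephemeral-credential ("TURN REST API") scheme; a self-contained sketch of the same derivation:

```go
package turncreds

import (
	"crypto/hmac"
	"crypto/sha1"
	"encoding/base64"
	"fmt"
	"time"
)

// ephemeralCredentials derives short-lived TURN credentials: the
// username is "<unix expiry>:<user id>" and the password is the base64
// HMAC-SHA1 of that username under the shared secret, matching the
// handler above.
func ephemeralCredentials(sharedSecret, userID string, ttl time.Duration) (username, password string, err error) {
	expiry := time.Now().Add(ttl).Unix()
	username = fmt.Sprintf("%d:%s", expiry, userID)
	mac := hmac.New(sha1.New, []byte(sharedSecret))
	if _, err = mac.Write([]byte(username)); err != nil {
		return "", "", err
	}
	password = base64.StdEncoding.EncodeToString(mac.Sum(nil))
	return username, password, nil
}
```

A TURN server configured with the same secret recomputes the MAC to validate the password and rejects credentials whose embedded expiry has passed.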
@@ -25,7 +25,6 @@ import (
 	"net"
 	"net/http"
 	"os"
-	"strings"
 	"time"

 	"github.com/gorilla/mux"

@@ -47,6 +46,7 @@ import (
 	"github.com/matrix-org/dendrite/userapi"
 	"github.com/matrix-org/gomatrixserverlib"

+	pineconeConnections "github.com/matrix-org/pinecone/connections"
 	pineconeMulticast "github.com/matrix-org/pinecone/multicast"
 	pineconeRouter "github.com/matrix-org/pinecone/router"
 	pineconeSessions "github.com/matrix-org/pinecone/sessions"

@@ -90,6 +90,13 @@ func main() {
 	}

 	pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
+	pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
+	pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
+	pManager := pineconeConnections.NewConnectionManager(pRouter)
+	pMulticast.Start()
+	if instancePeer != nil && *instancePeer != "" {
+		pManager.AddPeer(*instancePeer)
+	}

 	go func() {
 		listener, err := net.Listen("tcp", *instanceListen)

@@ -119,36 +126,6 @@ func main() {
 		}
 	}()

-	pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
-	pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
-	pMulticast.Start()
-
-	connectToStaticPeer := func() {
-		connected := map[string]bool{} // URI -> connected?
-		for _, uri := range strings.Split(*instancePeer, ",") {
-			connected[strings.TrimSpace(uri)] = false
-		}
-		attempt := func() {
-			for k := range connected {
-				connected[k] = false
-			}
-			for _, info := range pRouter.Peers() {
-				connected[info.URI] = true
-			}
-			for k, online := range connected {
-				if !online {
-					if err := conn.ConnectToPeer(pRouter, k); err != nil {
-						logrus.WithError(err).Error("Failed to connect to static peer")
-					}
-				}
-			}
-		}
-		for {
-			attempt()
-			time.Sleep(time.Second * 5)
-		}
-	}
-
 	cfg := &config.Dendrite{}
 	cfg.Defaults(true)
 	cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))

@@ -268,7 +245,6 @@ func main() {
 		Handler: pMux,
 	}

-	go connectToStaticPeer()
 	go func() {
 		pubkey := pRouter.PublicKey()
 		logrus.Info("Listening on ", hex.EncodeToString(pubkey[:]))
@@ -22,7 +22,6 @@ import (
 	"encoding/hex"
 	"fmt"
 	"syscall/js"
-	"time"

 	"github.com/gorilla/mux"
 	"github.com/matrix-org/dendrite/appservice"

@@ -44,6 +43,7 @@ import (

 	_ "github.com/matrix-org/go-sqlite3-js"

+	pineconeConnections "github.com/matrix-org/pinecone/connections"
 	pineconeRouter "github.com/matrix-org/pinecone/router"
 	pineconeSessions "github.com/matrix-org/pinecone/sessions"
 )

@@ -154,6 +154,8 @@ func startup() {

 	pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
 	pSessions := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
+	pManager := pineconeConnections.NewConnectionManager(pRouter)
+	pManager.AddPeer("wss://pinecone.matrix.org/public")

 	cfg := &config.Dendrite{}
 	cfg.Defaults(true)

@@ -237,20 +239,4 @@ func startup() {
 		}
 		s.ListenAndServe("fetch")
 	}()
-
-	// Connect to the static peer
-	go func() {
-		for {
-			if pRouter.PeerCount(pineconeRouter.PeerTypeRemote) == 0 {
-				if err := conn.ConnectToPeer(pRouter, publicPeer); err != nil {
-					logrus.WithError(err).Error("Failed to connect to static peer")
-				}
-			}
-			select {
-			case <-base.ProcessContext.Context().Done():
-				return
-			case <-time.After(time.Second * 5):
-			}
-		}
-	}()
 }
go.mod: 13 changed lines

@@ -1,8 +1,8 @@
 module github.com/matrix-org/dendrite

-replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e
+replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9

-replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c
+replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e

 require (
 	github.com/Arceliar/ironwood v0.0.0-20211125050254-8951369625d0

@@ -27,15 +27,15 @@ require (
 	github.com/juju/testing v0.0.0-20220203020004-a0ff61f03494 // indirect
 	github.com/lib/pq v1.10.5
 	github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
-	github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
+	github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
 	github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
-	github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5
-	github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d
+	github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f
+	github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48
 	github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
 	github.com/mattn/go-sqlite3 v1.14.10
 	github.com/miekg/dns v1.1.31 // indirect
 	github.com/nats-io/nats-server/v2 v2.7.4-0.20220309205833-773636c1c5bb
-	github.com/nats-io/nats.go v1.13.1-0.20220308171302-2f2f6968e98d
+	github.com/nats-io/nats.go v1.14.0
 	github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
 	github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
 	github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31

@@ -56,6 +56,7 @@ require (
 	golang.org/x/image v0.0.0-20220321031419-a8550c1d254a
 	golang.org/x/mobile v0.0.0-20220407111146-e579adbbc4a2
 	golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3
+	golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect
 	golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
 	gopkg.in/h2non/bimg.v1 v1.1.9
 	gopkg.in/yaml.v2 v2.4.0
go.sum: 27 changed lines

@@ -789,15 +789,15 @@ github.com/masterzen/winrm v0.0.0-20161014151040-7a535cd943fc/go.mod h1:CfZSN7zw
 github.com/masterzen/xmlpath v0.0.0-20140218185901-13f4951698ad/go.mod h1:A0zPC53iKKKcXYxr4ROjpQRQ5FgJXtelNdSmHHuq/tY=
 github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e h1:DP5RC0Z3XdyBEW5dKt8YPeN6vZbm6OzVaGVp7f1BQRM=
 github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg=
-github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d h1:mGhPVaTht5NViFN/UpdrIlRApmH2FWcVaKUH5MdBKiY=
-github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
+github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw2QV3YD/fRrzEDPNGgTlJlvXY0EHHnT87wF3OA=
+github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
-github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5 h1:Fkennny7+Z/5pygrhjFMZbz1j++P2hhhLoT7NO3p8DQ=
-github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48=
-github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d h1:1+T4eOPRsf6cr0lMPW4oO2k8TTHm4mqIh65kpEID5Rk=
-github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc=
+github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f h1:MZrl4TgTnlaOn2Cu9gJCoJ3oyW5mT4/3QIZGgZXzKl4=
+github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48=
+github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 h1:W0sjjC6yjskHX4mb0nk3p0fXAlbU5bAFUFeEtlrPASE=
+github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc=
 github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=

@@ -878,8 +878,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m
 github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
 github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
 github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
-github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296 h1:vU9tpM3apjYlLLeY23zRWJ9Zktr5jp+mloR942LEOpY=
-github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
+github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a h1:lem6QCvxR0Y28gth9P+wV2K/zYUUAkJ+55U8cpS0p5I=
+github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
 github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8=
 github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4=
 github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=

@@ -888,10 +888,10 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY
 github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM=
 github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
 github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
-github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e h1:5tEHLzvDeS6IeqO2o9FFhsE3V2erYj8FlMt2J91wzsk=
-github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e/go.mod h1:1vZ2Nijh8tcyNe8BDVyTviCd9NYzRbubQYiEHsvOQWc=
-github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q=
-github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
+github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9 h1:VGU5HYAwy8LRbSkrT+kCHvujVmwK8Aa/vc1O+eReTbM=
+github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9/go.mod h1:5vic7C58BFEVltiZhs7Kq81q2WcEPhJPsmNv1FOrdv0=
+github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e h1:kNIzIzj2OvnlreA+sTJ12nWJzTP3OSLNKDL/Iq9mF6Y=
+github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 h1:lrVQzBtkeQEGGYUHwSX1XPe1E5GL6U3KYCNe2G4bncQ=
 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8=
 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=

@@ -1541,8 +1541,9 @@ golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64 h1:D1v9ucDTYBtbz5vNuBbAhIMAGhQhJ6Ym5ah3maMVNX4=
 golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 h1:QyVthZKMsyaQwBTJE04jdNN0Pp5Fn9Qga0mrgxyERQM=
+golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
internal/caching/cache_lazy_load_members.go: 86 lines (new file)

@@ -0,0 +1,86 @@
+package caching
+
+import (
+	"fmt"
+	"time"
+
+	userapi "github.com/matrix-org/dendrite/userapi/api"
+)
+
+const (
+	LazyLoadCacheName           = "lazy_load_members"
+	LazyLoadCacheMaxEntries     = 128
+	LazyLoadCacheMaxUserEntries = 128
+	LazyLoadCacheMutable        = true
+	LazyLoadCacheMaxAge         = time.Minute * 30
+)
+
+type LazyLoadCache struct {
+	// InMemoryLRUCachePartition containing other InMemoryLRUCachePartitions
+	// with the actual cached members
+	userCaches *InMemoryLRUCachePartition
+}
+
+// NewLazyLoadCache creates a new LazyLoadCache.
+func NewLazyLoadCache() (*LazyLoadCache, error) {
+	cache, err := NewInMemoryLRUCachePartition(
+		LazyLoadCacheName,
+		LazyLoadCacheMutable,
+		LazyLoadCacheMaxEntries,
+		LazyLoadCacheMaxAge,
+		true,
+	)
+	if err != nil {
+		return nil, err
+	}
+	go cacheCleaner(cache)
+	return &LazyLoadCache{
+		userCaches: cache,
+	}, nil
+}
+
+func (c *LazyLoadCache) lazyLoadCacheForUser(device *userapi.Device) (*InMemoryLRUCachePartition, error) {
+	cacheName := fmt.Sprintf("%s/%s", device.UserID, device.ID)
+	userCache, ok := c.userCaches.Get(cacheName)
+	if ok && userCache != nil {
+		if cache, ok := userCache.(*InMemoryLRUCachePartition); ok {
+			return cache, nil
+		}
+	}
+	cache, err := NewInMemoryLRUCachePartition(
+		LazyLoadCacheName,
+		LazyLoadCacheMutable,
+		LazyLoadCacheMaxUserEntries,
+		LazyLoadCacheMaxAge,
+		false,
+	)
+	if err != nil {
+		return nil, err
+	}
+	c.userCaches.Set(cacheName, cache)
+	go cacheCleaner(cache)
+	return cache, nil
+}
+
+func (c *LazyLoadCache) StoreLazyLoadedUser(device *userapi.Device, roomID, userID, eventID string) {
+	cache, err := c.lazyLoadCacheForUser(device)
+	if err != nil {
+		return
+	}
+	cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID)
+	cache.Set(cacheKey, eventID)
+}
+
+func (c *LazyLoadCache) IsLazyLoadedUserCached(device *userapi.Device, roomID, userID string) (string, bool) {
+	cache, err := c.lazyLoadCacheForUser(device)
+	if err != nil {
+		return "", false
+	}
+
+	cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID)
+	val, ok := cache.Get(cacheKey)
+	if !ok {
+		return "", ok
+	}
+	return val.(string), ok
+}
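A minimal usage sketch for the new cache (hypothetical IDs; the calls are the functions defined in the file above):

```go
package example

import (
	"github.com/matrix-org/dendrite/internal/caching"
	userapi "github.com/matrix-org/dendrite/userapi/api"
)

// rememberMembership shows the intended round trip: record which
// membership event was last sent to a device for a room/user pair, and
// skip resending it while the entry is still cached.
func rememberMembership() error {
	cache, err := caching.NewLazyLoadCache()
	if err != nil {
		return err
	}
	device := &userapi.Device{UserID: "@alice:example.org", ID: "DEVICE1"}
	cache.StoreLazyLoadedUser(device, "!room:example.org", "@bob:example.org", "$event123")
	if eventID, ok := cache.IsLazyLoadedUserCached(device, "!room:example.org", "@bob:example.org"); ok {
		_ = eventID // already sent to this device; omit from the next lazy-loaded response
	}
	return nil
}
```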
@@ -3,8 +3,6 @@ package pushgateway
 import (
 	"context"
 	"encoding/json"
-
-	"github.com/matrix-org/gomatrixserverlib"
 )

 // A Client is how interactions with a Push Gateway is done.

@@ -47,11 +45,11 @@ type Counts struct {
 }

 type Device struct {
 	AppID     string                 `json:"app_id"`  // Required
 	Data      map[string]interface{} `json:"data"`    // Required. UNSPEC: Sytests require this to allow unknown keys.
 	PushKey   string                 `json:"pushkey"` // Required
-	PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
+	PushKeyTS int64                  `json:"pushkey_ts,omitempty"`
 	Tweaks    map[string]interface{} `json:"tweaks,omitempty"`
 }

 type Prio string
@@ -32,7 +32,7 @@ func AddPublicRoutes(
 	userAPI userapi.UserInternalAPI,
 	client *gomatrixserverlib.Client,
 ) {
-	mediaDB, err := storage.Open(&cfg.Database)
+	mediaDB, err := storage.NewMediaAPIDatasource(&cfg.Database)
 	if err != nil {
 		logrus.WithError(err).Panicf("failed to connect to media db")
 	}
@@ -22,6 +22,7 @@ import (
 	"io"
 	"net/http"
 	"net/url"
+	"os"
 	"path"
 	"strings"

@@ -311,6 +312,26 @@ func (r *uploadRequest) storeFileAndMetadata(
 	}

 	go func() {
+		file, err := os.Open(string(finalPath))
+		if err != nil {
+			r.Logger.WithError(err).Error("unable to open file")
+			return
+		}
+		defer file.Close() // nolint: errcheck
+		// http.DetectContentType only needs 512 bytes
+		buf := make([]byte, 512)
+		_, err = file.Read(buf)
+		if err != nil {
+			r.Logger.WithError(err).Error("unable to read file")
+			return
+		}
+		// Check if we need to generate thumbnails
+		fileType := http.DetectContentType(buf)
+		if !strings.HasPrefix(fileType, "image") {
+			r.Logger.WithField("contentType", fileType).Debugf("uploaded file is not an image or can not be thumbnailed, not generating thumbnails")
+			return
+		}
+
 		busy, err := thumbnailer.GenerateThumbnails(
 			context.Background(), finalPath, thumbnailSizes, r.MediaMetadata,
 			activeThumbnailGeneration, maxThumbnailGenerators, db, r.Logger,
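The added block decides whether to thumbnail by sniffing the stored file rather than trusting any client-supplied type. The same check in isolation, using only the standard library:

```go
package example

import (
	"net/http"
	"os"
	"strings"
)

// isThumbnailableImage mirrors the gate added above: read the first 512
// bytes (all http.DetectContentType ever inspects) and report whether
// the sniffed type is image/*.
func isThumbnailableImage(path string) (bool, error) {
	f, err := os.Open(path)
	if err != nil {
		return false, err
	}
	defer f.Close() // nolint: errcheck
	buf := make([]byte, 512)
	if _, err := f.Read(buf); err != nil {
		return false, err
	}
	return strings.HasPrefix(http.DetectContentType(buf), "image"), nil
}
```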
@@ -51,7 +51,7 @@ func Test_uploadRequest_doUpload(t *testing.T) {
 	_ = os.Mkdir(testdataPath, os.ModePerm)
 	defer fileutils.RemoveDir(types.Path(testdataPath), nil)

-	db, err := storage.Open(&config.DatabaseOptions{
+	db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{
 		ConnectionString:   "file::memory:?cache=shared",
 		MaxOpenConnections: 100,
 		MaxIdleConnections: 2,
@@ -22,9 +22,17 @@ import (
 )

 type Database interface {
+	MediaRepository
+	Thumbnails
+}
+
+type MediaRepository interface {
 	StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error
 	GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
 	GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
+}
+
+type Thumbnails interface {
 	StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error
 	GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error)
 	GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error)
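Splitting the Database interface into embedded MediaRepository and Thumbnails interfaces lets callers depend only on the capability they actually use. A sketch, assuming these interfaces live in the mediaapi storage package:

```go
package example

import (
	"context"

	"github.com/matrix-org/dendrite/mediaapi/storage"
	"github.com/matrix-org/dendrite/mediaapi/types"
	"github.com/matrix-org/gomatrixserverlib"
)

// fetchMeta needs read access to media metadata only, so it accepts the
// narrow MediaRepository interface rather than the full Database.
func fetchMeta(ctx context.Context, db storage.MediaRepository, id types.MediaID, origin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
	return db.GetMediaMetadata(ctx, id, origin)
}
```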
@@ -20,6 +20,8 @@ import (
 	"database/sql"
 	"time"

+	"github.com/matrix-org/dendrite/internal/sqlutil"
+	"github.com/matrix-org/dendrite/mediaapi/storage/tables"
 	"github.com/matrix-org/dendrite/mediaapi/types"
 	"github.com/matrix-org/gomatrixserverlib"
 )

@@ -69,24 +71,25 @@ type mediaStatements struct {
 	selectMediaByHashStmt *sql.Stmt
 }

-func (s *mediaStatements) prepare(db *sql.DB) (err error) {
-	_, err = db.Exec(mediaSchema)
+func NewPostgresMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) {
+	s := &mediaStatements{}
+	_, err := db.Exec(mediaSchema)
 	if err != nil {
-		return
+		return nil, err
 	}

-	return statementList{
+	return s, sqlutil.StatementList{
 		{&s.insertMediaStmt, insertMediaSQL},
 		{&s.selectMediaStmt, selectMediaSQL},
 		{&s.selectMediaByHashStmt, selectMediaByHashSQL},
-	}.prepare(db)
+	}.Prepare(db)
 }

-func (s *mediaStatements) insertMedia(
-	ctx context.Context, mediaMetadata *types.MediaMetadata,
+func (s *mediaStatements) InsertMedia(
+	ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata,
 ) error {
-	mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
-	_, err := s.insertMediaStmt.ExecContext(
+	mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
+	_, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext(
 		ctx,
 		mediaMetadata.MediaID,
 		mediaMetadata.Origin,

@@ -100,14 +103,14 @@ func (s *mediaStatements) insertMedia(
 	return err
 }

-func (s *mediaStatements) selectMedia(
-	ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
+func (s *mediaStatements) SelectMedia(
+	ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
 ) (*types.MediaMetadata, error) {
 	mediaMetadata := types.MediaMetadata{
 		MediaID: mediaID,
 		Origin:  mediaOrigin,
 	}
-	err := s.selectMediaStmt.QueryRowContext(
+	err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext(
 		ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
 	).Scan(
 		&mediaMetadata.ContentType,

@@ -120,14 +123,14 @@ func (s *mediaStatements) selectMedia(
 	return &mediaMetadata, err
 }

-func (s *mediaStatements) selectMediaByHash(
-	ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
+func (s *mediaStatements) SelectMediaByHash(
+	ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
 ) (*types.MediaMetadata, error) {
 	mediaMetadata := types.MediaMetadata{
 		Base64Hash: mediaHash,
 		Origin:     mediaOrigin,
 	}
-	err := s.selectMediaStmt.QueryRowContext(
+	err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext(
 		ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
 	).Scan(
 		&mediaMetadata.ContentType,
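The sqlutil.StatementList used here replaces the package-local statementList helper that this commit deletes further down (see the removed prepare.go). Reconstructed from that deleted file, and assuming the exported helper in internal/sqlutil has the same shape:

```go
package example

import "database/sql"

// StatementList pairs each destination *sql.Stmt with its SQL so a
// table type can prepare every statement in one call.
type StatementList []struct {
	Statement **sql.Stmt
	SQL       string
}

// Prepare compiles each statement and stores it where the entry points.
func (s StatementList) Prepare(db *sql.DB) (err error) {
	for _, entry := range s {
		if *entry.Statement, err = db.Prepare(entry.SQL); err != nil {
			return
		}
	}
	return
}
```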
mediaapi/storage/postgres/mediaapi.go: 46 lines (new file)

@@ -0,0 +1,46 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+	// Import the postgres database driver.
+	_ "github.com/lib/pq"
+	"github.com/matrix-org/dendrite/internal/sqlutil"
+	"github.com/matrix-org/dendrite/mediaapi/storage/shared"
+	"github.com/matrix-org/dendrite/setup/config"
+)
+
+// NewDatabase opens a postgres database.
+func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
+	db, err := sqlutil.Open(dbProperties)
+	if err != nil {
+		return nil, err
+	}
+	mediaRepo, err := NewPostgresMediaRepositoryTable(db)
+	if err != nil {
+		return nil, err
+	}
+	thumbnails, err := NewPostgresThumbnailsTable(db)
+	if err != nil {
+		return nil, err
+	}
+	return &shared.Database{
+		MediaRepository: mediaRepo,
+		Thumbnails:      thumbnails,
+		DB:              db,
+		Writer:          sqlutil.NewExclusiveWriter(),
+	}, nil
+}
@@ -1,38 +0,0 @@
-// Copyright 2017-2018 New Vector Ltd
-// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// FIXME: This should be made internal!
-
-package postgres
-
-import (
-	"database/sql"
-)
-
-// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
-type statementList []struct {
-	statement **sql.Stmt
-	sql       string
-}
-
-// prepare the SQL for each statement in the list and assign the result to the prepared statement.
-func (s statementList) prepare(db *sql.DB) (err error) {
-	for _, statement := range s {
-		if *statement.statement, err = db.Prepare(statement.sql); err != nil {
-			return
-		}
-	}
-	return
-}
@@ -1,36 +0,0 @@
-// Copyright 2017-2018 New Vector Ltd
-// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package postgres
-
-import (
-	"database/sql"
-)
-
-type statements struct {
-	media     mediaStatements
-	thumbnail thumbnailStatements
-}
-
-func (s *statements) prepare(db *sql.DB) (err error) {
-	if err = s.media.prepare(db); err != nil {
-		return
-	}
-	if err = s.thumbnail.prepare(db); err != nil {
-		return
-	}
-
-	return
-}
@@ -21,6 +21,8 @@ import (
 	"time"

 	"github.com/matrix-org/dendrite/internal"
+	"github.com/matrix-org/dendrite/internal/sqlutil"
+	"github.com/matrix-org/dendrite/mediaapi/storage/tables"
 	"github.com/matrix-org/dendrite/mediaapi/types"
 	"github.com/matrix-org/gomatrixserverlib"
 )

@@ -63,7 +65,7 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE

 // Note: this selects all thumbnails for a media_origin and media_id
 const selectThumbnailsSQL = `
-SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2
+SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC
 `

 type thumbnailStatements struct {

@@ -72,24 +74,25 @@ type thumbnailStatements struct {
 	selectThumbnailsStmt *sql.Stmt
 }

-func (s *thumbnailStatements) prepare(db *sql.DB) (err error) {
-	_, err = db.Exec(thumbnailSchema)
+func NewPostgresThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) {
+	s := &thumbnailStatements{}
+	_, err := db.Exec(thumbnailSchema)
 	if err != nil {
-		return
+		return nil, err
 	}

-	return statementList{
+	return s, sqlutil.StatementList{
 		{&s.insertThumbnailStmt, insertThumbnailSQL},
 		{&s.selectThumbnailStmt, selectThumbnailSQL},
 		{&s.selectThumbnailsStmt, selectThumbnailsSQL},
-	}.prepare(db)
+	}.Prepare(db)
 }

-func (s *thumbnailStatements) insertThumbnail(
-	ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
+func (s *thumbnailStatements) InsertThumbnail(
+	ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata,
 ) error {
-	thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
-	_, err := s.insertThumbnailStmt.ExecContext(
+	thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
+	_, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext(
 		ctx,
 		thumbnailMetadata.MediaMetadata.MediaID,
 		thumbnailMetadata.MediaMetadata.Origin,

@@ -103,8 +106,9 @@ func (s *thumbnailStatements) insertThumbnail(
 	return err
 }

-func (s *thumbnailStatements) selectThumbnail(
+func (s *thumbnailStatements) SelectThumbnail(
 	ctx context.Context,
+	txn *sql.Tx,
 	mediaID types.MediaID,
 	mediaOrigin gomatrixserverlib.ServerName,
 	width, height int,

@@ -121,7 +125,7 @@ func (s *thumbnailStatements) selectThumbnail(
 			ResizeMethod: resizeMethod,
 		},
 	}
-	err := s.selectThumbnailStmt.QueryRowContext(
+	err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext(
 		ctx,
 		thumbnailMetadata.MediaMetadata.MediaID,
 		thumbnailMetadata.MediaMetadata.Origin,

@@ -136,10 +140,10 @@ func (s *thumbnailStatements) selectThumbnail(
 	return &thumbnailMetadata, err
 }

-func (s *thumbnailStatements) selectThumbnails(
-	ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
+func (s *thumbnailStatements) SelectThumbnails(
+	ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
 ) ([]*types.ThumbnailMetadata, error) {
-	rows, err := s.selectThumbnailsStmt.QueryContext(
+	rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext(
 		ctx, mediaID, mediaOrigin,
 	)
 	if err != nil {
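Every statement now runs through sqlutil.TxStmtContext so the same prepared statement can be used inside or outside a transaction. Assuming it behaves like database/sql's Tx.StmtContext, the helper amounts to:

```go
package example

import (
	"context"
	"database/sql"
)

// txStmtContext binds stmt to txn when a transaction is supplied and
// falls back to the bare statement otherwise (a sketch of the assumed
// sqlutil helper).
func txStmtContext(ctx context.Context, txn *sql.Tx, stmt *sql.Stmt) *sql.Stmt {
	if txn != nil {
		return txn.StmtContext(ctx, stmt)
	}
	return stmt
}
```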
@ -1,5 +1,4 @@
|
||||||
// Copyright 2017-2018 New Vector Ltd
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
|
|
@ -13,54 +12,38 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package postgres
|
package shared
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
// Import the postgres database driver.
|
|
||||||
_ "github.com/lib/pq"
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Database is used to store metadata about a repository of media files.
|
|
||||||
type Database struct {
|
type Database struct {
|
||||||
statements statements
|
DB *sql.DB
|
||||||
db *sql.DB
|
Writer sqlutil.Writer
|
||||||
}
|
MediaRepository tables.MediaRepository
|
||||||
|
Thumbnails tables.Thumbnails
|
||||||
// Open opens a postgres database.
|
|
||||||
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
|
||||||
var d Database
|
|
||||||
var err error
|
|
||||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.statements.prepare(d.db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &d, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
|
// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
|
||||||
// Returns an error if the combination of MediaID and Origin are not unique in the table.
|
// Returns an error if the combination of MediaID and Origin are not unique in the table.
|
||||||
func (d *Database) StoreMediaMetadata(
|
func (d Database) StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error {
|
||||||
ctx context.Context, mediaMetadata *types.MediaMetadata,
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
) error {
|
return d.MediaRepository.InsertMedia(ctx, txn, mediaMetadata)
|
||||||
return d.statements.media.insertMedia(ctx, mediaMetadata)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMediaMetadata returns metadata about media stored on this server.
|
// GetMediaMetadata returns metadata about media stored on this server.
|
||||||
 // The media could have been uploaded to this server or fetched from another server and cached here.
 // Returns nil metadata if there is no metadata associated with this media.
-func (d *Database) GetMediaMetadata(
-	ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
-) (*types.MediaMetadata, error) {
-	mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin)
+func (d Database) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
+	mediaMetadata, err := d.MediaRepository.SelectMedia(ctx, nil, mediaID, mediaOrigin)
 	if err != nil && err == sql.ErrNoRows {
 		return nil, nil
 	}

@@ -70,10 +53,8 @@ func (d *Database) GetMediaMetadata(

 // GetMediaMetadataByHash returns metadata about media stored on this server.
 // The media could have been uploaded to this server or fetched from another server and cached here.
 // Returns nil metadata if there is no metadata associated with this media.
-func (d *Database) GetMediaMetadataByHash(
-	ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
-) (*types.MediaMetadata, error) {
-	mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin)
+func (d Database) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
+	mediaMetadata, err := d.MediaRepository.SelectMediaByHash(ctx, nil, mediaHash, mediaOrigin)
 	if err != nil && err == sql.ErrNoRows {
 		return nil, nil
 	}

@@ -82,40 +63,36 @@ func (d *Database) GetMediaMetadataByHash(

 // StoreThumbnail inserts the metadata about the thumbnail into the database.
 // Returns an error if the combination of MediaID and Origin are not unique in the table.
-func (d *Database) StoreThumbnail(
-	ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
-) error {
-	return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata)
+func (d Database) StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error {
+	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+		return d.Thumbnails.InsertThumbnail(ctx, txn, thumbnailMetadata)
+	})
 }

 // GetThumbnail returns metadata about a specific thumbnail.
 // The media could have been uploaded to this server or fetched from another server and cached here.
 // Returns nil metadata if there is no metadata associated with this thumbnail.
-func (d *Database) GetThumbnail(
-	ctx context.Context,
-	mediaID types.MediaID,
-	mediaOrigin gomatrixserverlib.ServerName,
-	width, height int,
-	resizeMethod string,
-) (*types.ThumbnailMetadata, error) {
-	thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail(
-		ctx, mediaID, mediaOrigin, width, height, resizeMethod,
-	)
-	if err != nil && err == sql.ErrNoRows {
-		return nil, nil
+func (d Database) GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) {
+	metadata, err := d.Thumbnails.SelectThumbnail(ctx, nil, mediaID, mediaOrigin, width, height, resizeMethod)
+	if err != nil {
+		if err == sql.ErrNoRows {
+			return nil, nil
+		}
+		return nil, err
 	}
-	return thumbnailMetadata, err
+	return metadata, err
 }

 // GetThumbnails returns metadata about all thumbnails for a specific media stored on this server.
 // The media could have been uploaded to this server or fetched from another server and cached here.
 // Returns nil metadata if there are no thumbnails associated with this media.
-func (d *Database) GetThumbnails(
-	ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
-) ([]*types.ThumbnailMetadata, error) {
-	thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin)
-	if err != nil && err == sql.ErrNoRows {
-		return nil, nil
+func (d Database) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) {
+	metadatas, err := d.Thumbnails.SelectThumbnails(ctx, nil, mediaID, mediaOrigin)
+	if err != nil {
+		if err == sql.ErrNoRows {
+			return nil, nil
+		}
+		return nil, err
 	}
-	return thumbnails, err
+	return metadatas, err
 }
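The pattern above recurs throughout the reworked media API storage: reads call the table interfaces directly with a nil transaction, while writes are funnelled through the exclusive Writer so SQLite only ever sees one writer at a time. A minimal sketch of a caller (not part of the diff), assuming the shared.Database fields shown above:

package example

import (
	"context"
	"database/sql"

	"github.com/matrix-org/dendrite/mediaapi/storage/shared"
	"github.com/matrix-org/dendrite/mediaapi/types"
)

// storeAndFetch is a hypothetical helper, for illustration only.
func storeAndFetch(ctx context.Context, d *shared.Database, m *types.MediaMetadata) (*types.MediaMetadata, error) {
	// Writes are serialised through the Writer, mirroring StoreThumbnail above.
	if err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		return d.MediaRepository.InsertMedia(ctx, txn, m)
	}); err != nil {
		return nil, err
	}
	// Reads pass a nil transaction straight to the table layer.
	return d.MediaRepository.SelectMedia(ctx, nil, m.MediaID, m.Origin)
}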
@@ -21,6 +21,7 @@ import (
 	"time"

 	"github.com/matrix-org/dendrite/internal/sqlutil"
+	"github.com/matrix-org/dendrite/mediaapi/storage/tables"
 	"github.com/matrix-org/dendrite/mediaapi/types"
 	"github.com/matrix-org/gomatrixserverlib"
 )

@@ -66,57 +67,53 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, media_id, user_i

 type mediaStatements struct {
 	db *sql.DB
-	writer sqlutil.Writer
 	insertMediaStmt       *sql.Stmt
 	selectMediaStmt       *sql.Stmt
 	selectMediaByHashStmt *sql.Stmt
 }

-func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
-	s.db = db
-	s.writer = writer
-
-	_, err = db.Exec(mediaSchema)
+func NewSQLiteMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) {
+	s := &mediaStatements{
+		db: db,
+	}
+	_, err := db.Exec(mediaSchema)
 	if err != nil {
-		return
+		return nil, err
 	}

-	return statementList{
+	return s, sqlutil.StatementList{
 		{&s.insertMediaStmt, insertMediaSQL},
 		{&s.selectMediaStmt, selectMediaSQL},
 		{&s.selectMediaByHashStmt, selectMediaByHashSQL},
-	}.prepare(db)
+	}.Prepare(db)
 }

-func (s *mediaStatements) insertMedia(
-	ctx context.Context, mediaMetadata *types.MediaMetadata,
+func (s *mediaStatements) InsertMedia(
+	ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata,
 ) error {
-	mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
-	return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
-		stmt := sqlutil.TxStmt(txn, s.insertMediaStmt)
-		_, err := stmt.ExecContext(
-			ctx,
-			mediaMetadata.MediaID,
-			mediaMetadata.Origin,
-			mediaMetadata.ContentType,
-			mediaMetadata.FileSizeBytes,
-			mediaMetadata.CreationTimestamp,
-			mediaMetadata.UploadName,
-			mediaMetadata.Base64Hash,
-			mediaMetadata.UserID,
-		)
-		return err
-	})
+	mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
+	_, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext(
+		ctx,
+		mediaMetadata.MediaID,
+		mediaMetadata.Origin,
+		mediaMetadata.ContentType,
+		mediaMetadata.FileSizeBytes,
+		mediaMetadata.CreationTimestamp,
+		mediaMetadata.UploadName,
+		mediaMetadata.Base64Hash,
+		mediaMetadata.UserID,
+	)
+	return err
 }

-func (s *mediaStatements) selectMedia(
-	ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
+func (s *mediaStatements) SelectMedia(
+	ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
 ) (*types.MediaMetadata, error) {
 	mediaMetadata := types.MediaMetadata{
 		MediaID: mediaID,
 		Origin:  mediaOrigin,
 	}
-	err := s.selectMediaStmt.QueryRowContext(
+	err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext(
 		ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
 	).Scan(
 		&mediaMetadata.ContentType,

@@ -129,14 +126,14 @@ func (s *mediaStatements) selectMedia(
 	return &mediaMetadata, err
 }

-func (s *mediaStatements) selectMediaByHash(
-	ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
+func (s *mediaStatements) SelectMediaByHash(
+	ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
 ) (*types.MediaMetadata, error) {
 	mediaMetadata := types.MediaMetadata{
 		Base64Hash: mediaHash,
 		Origin:     mediaOrigin,
 	}
-	err := s.selectMediaStmt.QueryRowContext(
+	err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext(
 		ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
 	).Scan(
 		&mediaMetadata.ContentType,
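The table constructor above follows the sqlutil.StatementList pattern that replaces the file-local statementList helper deleted later in this diff: each entry pairs a SQL string with the *sql.Stmt pointer it should populate, and Prepare runs them all. A sketch of the shape, using a hypothetical table:

package example

import (
	"database/sql"

	"github.com/matrix-org/dendrite/internal/sqlutil"
)

type fooStatements struct {
	selectFooStmt *sql.Stmt // hypothetical statement, for illustration only
}

func newFooTable(db *sql.DB) (*fooStatements, error) {
	s := &fooStatements{}
	// Each entry binds one SQL string to the statement pointer it prepares.
	return s, sqlutil.StatementList{
		{&s.selectFooStmt, "SELECT 1"},
	}.Prepare(db)
}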
@@ -16,23 +16,30 @@
 package sqlite3

 import (
-	"database/sql"
-
+	// Import the postgres database driver.
 	"github.com/matrix-org/dendrite/internal/sqlutil"
+	"github.com/matrix-org/dendrite/mediaapi/storage/shared"
+	"github.com/matrix-org/dendrite/setup/config"
 )

-type statements struct {
-	media     mediaStatements
-	thumbnail thumbnailStatements
-}
-
-func (s *statements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
-	if err = s.media.prepare(db, writer); err != nil {
-		return
+// NewDatabase opens a SQLIte database.
+func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
+	db, err := sqlutil.Open(dbProperties)
+	if err != nil {
+		return nil, err
 	}
-	if err = s.thumbnail.prepare(db, writer); err != nil {
-		return
+	mediaRepo, err := NewSQLiteMediaRepositoryTable(db)
+	if err != nil {
+		return nil, err
 	}
-	return
+	thumbnails, err := NewSQLiteThumbnailsTable(db)
+	if err != nil {
+		return nil, err
+	}
+	return &shared.Database{
+		MediaRepository: mediaRepo,
+		Thumbnails:      thumbnails,
+		DB:              db,
+		Writer:          sqlutil.NewExclusiveWriter(),
+	}, nil
 }
@@ -1,38 +0,0 @@
-// Copyright 2017-2018 New Vector Ltd
-// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// FIXME: This should be made internal!
-
-package sqlite3
-
-import (
-	"database/sql"
-)
-
-// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
-type statementList []struct {
-	statement **sql.Stmt
-	sql       string
-}
-
-// prepare the SQL for each statement in the list and assign the result to the prepared statement.
-func (s statementList) prepare(db *sql.DB) (err error) {
-	for _, statement := range s {
-		if *statement.statement, err = db.Prepare(statement.sql); err != nil {
-			return
-		}
-	}
-	return
-}
@@ -1,123 +0,0 @@
-// Copyright 2017-2018 New Vector Ltd
-// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package sqlite3
-
-import (
-	"context"
-	"database/sql"
-
-	// Import the postgres database driver.
-	"github.com/matrix-org/dendrite/internal/sqlutil"
-	"github.com/matrix-org/dendrite/mediaapi/types"
-	"github.com/matrix-org/dendrite/setup/config"
-	"github.com/matrix-org/gomatrixserverlib"
-)
-
-// Database is used to store metadata about a repository of media files.
-type Database struct {
-	statements statements
-	db         *sql.DB
-	writer     sqlutil.Writer
-}
-
-// Open opens a postgres database.
-func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
-	d := Database{
-		writer: sqlutil.NewExclusiveWriter(),
-	}
-	var err error
-	if d.db, err = sqlutil.Open(dbProperties); err != nil {
-		return nil, err
-	}
-	if err = d.statements.prepare(d.db, d.writer); err != nil {
-		return nil, err
-	}
-	return &d, nil
-}
-
-// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
-// Returns an error if the combination of MediaID and Origin are not unique in the table.
-func (d *Database) StoreMediaMetadata(
-	ctx context.Context, mediaMetadata *types.MediaMetadata,
-) error {
-	return d.statements.media.insertMedia(ctx, mediaMetadata)
-}
-
-// GetMediaMetadata returns metadata about media stored on this server.
-// The media could have been uploaded to this server or fetched from another server and cached here.
-// Returns nil metadata if there is no metadata associated with this media.
-func (d *Database) GetMediaMetadata(
-	ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
-) (*types.MediaMetadata, error) {
-	mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin)
-	if err != nil && err == sql.ErrNoRows {
-		return nil, nil
-	}
-	return mediaMetadata, err
-}
-
-// GetMediaMetadataByHash returns metadata about media stored on this server.
-// The media could have been uploaded to this server or fetched from another server and cached here.
-// Returns nil metadata if there is no metadata associated with this media.
-func (d *Database) GetMediaMetadataByHash(
-	ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
-) (*types.MediaMetadata, error) {
-	mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin)
-	if err != nil && err == sql.ErrNoRows {
-		return nil, nil
-	}
-	return mediaMetadata, err
-}
-
-// StoreThumbnail inserts the metadata about the thumbnail into the database.
-// Returns an error if the combination of MediaID and Origin are not unique in the table.
-func (d *Database) StoreThumbnail(
-	ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
-) error {
-	return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata)
-}
-
-// GetThumbnail returns metadata about a specific thumbnail.
-// The media could have been uploaded to this server or fetched from another server and cached here.
-// Returns nil metadata if there is no metadata associated with this thumbnail.
-func (d *Database) GetThumbnail(
-	ctx context.Context,
-	mediaID types.MediaID,
-	mediaOrigin gomatrixserverlib.ServerName,
-	width, height int,
-	resizeMethod string,
-) (*types.ThumbnailMetadata, error) {
-	thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail(
-		ctx, mediaID, mediaOrigin, width, height, resizeMethod,
-	)
-	if err != nil && err == sql.ErrNoRows {
-		return nil, nil
-	}
-	return thumbnailMetadata, err
-}
-
-// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server.
-// The media could have been uploaded to this server or fetched from another server and cached here.
-// Returns nil metadata if there are no thumbnails associated with this media.
-func (d *Database) GetThumbnails(
-	ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
-) ([]*types.ThumbnailMetadata, error) {
-	thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin)
-	if err != nil && err == sql.ErrNoRows {
-		return nil, nil
-	}
-	return thumbnails, err
-}
@@ -22,6 +22,7 @@ import (
 	"github.com/matrix-org/dendrite/internal"
 	"github.com/matrix-org/dendrite/internal/sqlutil"
+	"github.com/matrix-org/dendrite/mediaapi/storage/tables"
 	"github.com/matrix-org/dendrite/mediaapi/types"
 	"github.com/matrix-org/gomatrixserverlib"
 )

@@ -54,55 +55,48 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE

 // Note: this selects all thumbnails for a media_origin and media_id
 const selectThumbnailsSQL = `
-SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2
+SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC
 `

 type thumbnailStatements struct {
-	db     *sql.DB
-	writer sqlutil.Writer
 	insertThumbnailStmt  *sql.Stmt
 	selectThumbnailStmt  *sql.Stmt
 	selectThumbnailsStmt *sql.Stmt
 }

-func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
-	_, err = db.Exec(thumbnailSchema)
+func NewSQLiteThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) {
+	s := &thumbnailStatements{}
+	_, err := db.Exec(thumbnailSchema)
 	if err != nil {
-		return
+		return nil, err
 	}
-	s.db = db
-	s.writer = writer

-	return statementList{
+	return s, sqlutil.StatementList{
 		{&s.insertThumbnailStmt, insertThumbnailSQL},
 		{&s.selectThumbnailStmt, selectThumbnailSQL},
 		{&s.selectThumbnailsStmt, selectThumbnailsSQL},
-	}.prepare(db)
+	}.Prepare(db)
 }

-func (s *thumbnailStatements) insertThumbnail(
-	ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
-) error {
-	thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
-	return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
-		stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt)
-		_, err := stmt.ExecContext(
-			ctx,
-			thumbnailMetadata.MediaMetadata.MediaID,
-			thumbnailMetadata.MediaMetadata.Origin,
-			thumbnailMetadata.MediaMetadata.ContentType,
-			thumbnailMetadata.MediaMetadata.FileSizeBytes,
-			thumbnailMetadata.MediaMetadata.CreationTimestamp,
-			thumbnailMetadata.ThumbnailSize.Width,
-			thumbnailMetadata.ThumbnailSize.Height,
-			thumbnailMetadata.ThumbnailSize.ResizeMethod,
-		)
-		return err
-	})
+func (s *thumbnailStatements) InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error {
+	thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
+	_, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext(
+		ctx,
+		thumbnailMetadata.MediaMetadata.MediaID,
+		thumbnailMetadata.MediaMetadata.Origin,
+		thumbnailMetadata.MediaMetadata.ContentType,
+		thumbnailMetadata.MediaMetadata.FileSizeBytes,
+		thumbnailMetadata.MediaMetadata.CreationTimestamp,
+		thumbnailMetadata.ThumbnailSize.Width,
+		thumbnailMetadata.ThumbnailSize.Height,
+		thumbnailMetadata.ThumbnailSize.ResizeMethod,
+	)
+	return err
 }

-func (s *thumbnailStatements) selectThumbnail(
+func (s *thumbnailStatements) SelectThumbnail(
 	ctx context.Context,
+	txn *sql.Tx,
 	mediaID types.MediaID,
 	mediaOrigin gomatrixserverlib.ServerName,
 	width, height int,

@@ -119,7 +113,7 @@ func (s *thumbnailStatements) selectThumbnail(
 			ResizeMethod: resizeMethod,
 		},
 	}
-	err := s.selectThumbnailStmt.QueryRowContext(
+	err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext(
 		ctx,
 		thumbnailMetadata.MediaMetadata.MediaID,
 		thumbnailMetadata.MediaMetadata.Origin,

@@ -134,10 +128,11 @@ func (s *thumbnailStatements) selectThumbnail(
 	return &thumbnailMetadata, err
 }

-func (s *thumbnailStatements) selectThumbnails(
-	ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
+func (s *thumbnailStatements) SelectThumbnails(
+	ctx context.Context, txn *sql.Tx, mediaID types.MediaID,
+	mediaOrigin gomatrixserverlib.ServerName,
 ) ([]*types.ThumbnailMetadata, error) {
-	rows, err := s.selectThumbnailsStmt.QueryContext(
+	rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext(
 		ctx, mediaID, mediaOrigin,
 	)
 	if err != nil {
@@ -25,13 +25,13 @@ import (
 	"github.com/matrix-org/dendrite/setup/config"
 )

-// Open opens a postgres database.
-func Open(dbProperties *config.DatabaseOptions) (Database, error) {
+// NewMediaAPIDatasource opens a database connection.
+func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
 	switch {
 	case dbProperties.ConnectionString.IsSQLite():
-		return sqlite3.Open(dbProperties)
+		return sqlite3.NewDatabase(dbProperties)
 	case dbProperties.ConnectionString.IsPostgres():
-		return postgres.Open(dbProperties)
+		return postgres.NewDatabase(dbProperties)
 	default:
 		return nil, fmt.Errorf("unexpected database type")
 	}
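Callers now open the media database through the renamed constructor. A hedged sketch (the file: connection string is an assumption about the SQLite DSN format, not taken from this diff):

package example

import (
	"github.com/matrix-org/dendrite/mediaapi/storage"
	"github.com/matrix-org/dendrite/setup/config"
)

func openMediaDB() (storage.Database, error) {
	// NewMediaAPIDatasource dispatches to SQLite or Postgres by inspecting the DSN.
	return storage.NewMediaAPIDatasource(&config.DatabaseOptions{
		ConnectionString: config.DataSource("file:mediaapi.db"),
	})
}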
mediaapi/storage/storage_test.go (new file, 135 lines)

@@ -0,0 +1,135 @@
+package storage_test
+
+import (
+	"context"
+	"reflect"
+	"testing"
+
+	"github.com/matrix-org/dendrite/mediaapi/storage"
+	"github.com/matrix-org/dendrite/mediaapi/types"
+	"github.com/matrix-org/dendrite/setup/config"
+	"github.com/matrix-org/dendrite/test"
+)
+
+func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
+	connStr, close := test.PrepareDBConnectionString(t, dbType)
+	db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{
+		ConnectionString: config.DataSource(connStr),
+	})
+	if err != nil {
+		t.Fatalf("NewSyncServerDatasource returned %s", err)
+	}
+	return db, close
+}
+func TestMediaRepository(t *testing.T) {
+	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+		db, close := mustCreateDatabase(t, dbType)
+		defer close()
+		ctx := context.Background()
+		t.Run("can insert media & query media", func(t *testing.T) {
+			metadata := &types.MediaMetadata{
+				MediaID:       "testing",
+				Origin:        "localhost",
+				ContentType:   "image/png",
+				FileSizeBytes: 10,
+				UploadName:    "upload test",
+				Base64Hash:    "dGVzdGluZw==",
+				UserID:        "@alice:localhost",
+			}
+			if err := db.StoreMediaMetadata(ctx, metadata); err != nil {
+				t.Fatalf("unable to store media metadata: %v", err)
+			}
+			// query by media id
+			gotMetadata, err := db.GetMediaMetadata(ctx, metadata.MediaID, metadata.Origin)
+			if err != nil {
+				t.Fatalf("unable to query media metadata: %v", err)
+			}
+			if !reflect.DeepEqual(metadata, gotMetadata) {
+				t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata)
+			}
+			// query by media hash
+			gotMetadata, err = db.GetMediaMetadataByHash(ctx, metadata.Base64Hash, metadata.Origin)
+			if err != nil {
+				t.Fatalf("unable to query media metadata by hash: %v", err)
+			}
+			if !reflect.DeepEqual(metadata, gotMetadata) {
+				t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata)
+			}
+		})
+	})
+}
+
+func TestThumbnailsStorage(t *testing.T) {
+	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+		db, close := mustCreateDatabase(t, dbType)
+		defer close()
+		ctx := context.Background()
+		t.Run("can insert thumbnails & query media", func(t *testing.T) {
+			thumbnails := []*types.ThumbnailMetadata{
+				{
+					MediaMetadata: &types.MediaMetadata{
+						MediaID:       "testing",
+						Origin:        "localhost",
+						ContentType:   "image/png",
+						FileSizeBytes: 6,
+					},
+					ThumbnailSize: types.ThumbnailSize{
+						Width:        5,
+						Height:       5,
+						ResizeMethod: types.Crop,
+					},
+				},
+				{
+					MediaMetadata: &types.MediaMetadata{
+						MediaID:       "testing",
+						Origin:        "localhost",
+						ContentType:   "image/png",
+						FileSizeBytes: 7,
+					},
+					ThumbnailSize: types.ThumbnailSize{
+						Width:        1,
+						Height:       1,
+						ResizeMethod: types.Scale,
+					},
+				},
+			}
+			for i := range thumbnails {
+				if err := db.StoreThumbnail(ctx, thumbnails[i]); err != nil {
+					t.Fatalf("unable to store thumbnail metadata: %v", err)
+				}
+			}
+			// query by single thumbnail
+			gotMetadata, err := db.GetThumbnail(ctx,
+				thumbnails[0].MediaMetadata.MediaID,
+				thumbnails[0].MediaMetadata.Origin,
+				thumbnails[0].ThumbnailSize.Width, thumbnails[0].ThumbnailSize.Height,
+				thumbnails[0].ThumbnailSize.ResizeMethod,
+			)
+			if err != nil {
+				t.Fatalf("unable to query thumbnail metadata: %v", err)
+			}
+			if !reflect.DeepEqual(thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata) {
+				t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata)
+			}
+			if !reflect.DeepEqual(thumbnails[0].ThumbnailSize, gotMetadata.ThumbnailSize) {
+				t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata)
+			}
+			// query by all thumbnails
+			gotMediadatas, err := db.GetThumbnails(ctx, thumbnails[0].MediaMetadata.MediaID, thumbnails[0].MediaMetadata.Origin)
+			if err != nil {
+				t.Fatalf("unable to query media metadata by hash: %v", err)
+			}
+			if len(gotMediadatas) != len(thumbnails) {
+				t.Fatalf("expected %d stored thumbnail metadata, got %d", len(thumbnails), len(gotMediadatas))
+			}
+			for i := range gotMediadatas {
+				if !reflect.DeepEqual(thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata) {
+					t.Fatalf("expected metadata %+v, got %v", thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata)
+				}
+				if !reflect.DeepEqual(thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize) {
+					t.Fatalf("expected metadata %+v, got %v", thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize)
+				}
+			}
+		})
+	})
+}
@@ -22,10 +22,10 @@ import (
 )

 // Open opens a postgres database.
-func Open(dbProperties *config.DatabaseOptions) (Database, error) {
+func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
 	switch {
 	case dbProperties.ConnectionString.IsSQLite():
-		return sqlite3.Open(dbProperties)
+		return sqlite3.NewDatabase(dbProperties)
 	case dbProperties.ConnectionString.IsPostgres():
 		return nil, fmt.Errorf("can't use Postgres implementation")
 	default:
mediaapi/storage/tables/interface.go (new file, 46 lines)

@@ -0,0 +1,46 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tables
+
+import (
+	"context"
+	"database/sql"
+
+	"github.com/matrix-org/dendrite/mediaapi/types"
+	"github.com/matrix-org/gomatrixserverlib"
+)
+
+type Thumbnails interface {
+	InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error
+	SelectThumbnail(
+		ctx context.Context, txn *sql.Tx,
+		mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
+		width, height int,
+		resizeMethod string,
+	) (*types.ThumbnailMetadata, error)
+	SelectThumbnails(
+		ctx context.Context, txn *sql.Tx, mediaID types.MediaID,
+		mediaOrigin gomatrixserverlib.ServerName,
+	) ([]*types.ThumbnailMetadata, error)
+}
+
+type MediaRepository interface {
+	InsertMedia(ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata) error
+	SelectMedia(ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
+	SelectMediaByHash(
+		ctx context.Context, txn *sql.Tx,
+		mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
+	) (*types.MediaMetadata, error)
+}
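Because the storage layer is now expressed as interfaces, callers can be exercised against in-memory fakes. A sketch (not from the diff) of a fake satisfying tables.MediaRepository; note the *sql.Tx parameter stays in the signature even though a fake ignores it:

package example

import (
	"context"
	"database/sql"

	"github.com/matrix-org/dendrite/mediaapi/storage/tables"
	"github.com/matrix-org/dendrite/mediaapi/types"
	"github.com/matrix-org/gomatrixserverlib"
)

type fakeMediaRepository struct {
	byID map[types.MediaID]*types.MediaMetadata
}

func (f *fakeMediaRepository) InsertMedia(_ context.Context, _ *sql.Tx, m *types.MediaMetadata) error {
	f.byID[m.MediaID] = m
	return nil
}

func (f *fakeMediaRepository) SelectMedia(_ context.Context, _ *sql.Tx, id types.MediaID, _ gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
	return f.byID[id], nil
}

func (f *fakeMediaRepository) SelectMediaByHash(_ context.Context, _ *sql.Tx, hash types.Base64Hash, _ gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
	for _, m := range f.byID {
		if m.Base64Hash == hash {
			return m, nil
		}
	}
	return nil, nil
}

// Compile-time check that the fake satisfies the interface above.
var _ tables.MediaRepository = (*fakeMediaRepository)(nil)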
@@ -45,16 +45,13 @@ type RequestMethod string
 // MatrixUserID is a Matrix user ID string in the form @user:domain e.g. @alice:matrix.org
 type MatrixUserID string

-// UnixMs is the milliseconds since the Unix epoch
-type UnixMs int64
-
 // MediaMetadata is metadata associated with a media file
 type MediaMetadata struct {
 	MediaID           MediaID
 	Origin            gomatrixserverlib.ServerName
 	ContentType       ContentType
 	FileSizeBytes     FileSizeBytes
-	CreationTimestamp UnixMs
+	CreationTimestamp gomatrixserverlib.Timestamp
 	UploadName        Filename
 	Base64Hash        Base64Hash
 	UserID            MatrixUserID
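Dropping the bespoke UnixMs type in favour of gomatrixserverlib.Timestamp keeps the same wire representation (milliseconds since the Unix epoch) while reusing the library's conversion helper, as the table code earlier in this diff already does:

package example

import (
	"time"

	"github.com/matrix-org/gomatrixserverlib"
)

func nowMs() gomatrixserverlib.Timestamp {
	// Replaces the old types.UnixMs(time.Now().UnixNano() / 1000000).
	return gomatrixserverlib.AsTimestamp(time.Now())
}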
@@ -36,7 +36,7 @@ import (
 type Notifier struct {
 	lock *sync.RWMutex
 	// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
-	roomIDToJoinedUsers map[string]userIDSet
+	roomIDToJoinedUsers map[string]*userIDSet
 	// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
 	roomIDToPeekingDevices map[string]peekingDeviceSet
 	// The latest sync position

@@ -54,7 +54,7 @@ type Notifier struct {
 // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
 func NewNotifier() *Notifier {
 	return &Notifier{
-		roomIDToJoinedUsers:    make(map[string]userIDSet),
+		roomIDToJoinedUsers:    make(map[string]*userIDSet),
 		roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
 		userDeviceStreams:      make(map[string]map[string]*UserDeviceStream),
 		lock:                   &sync.RWMutex{},

@@ -262,7 +262,7 @@ func (n *Notifier) SharedUsers(userID string) []string {
 func (n *Notifier) _sharedUsers(userID string) []string {
 	n._sharedUserMap[userID] = struct{}{}
 	for roomID, users := range n.roomIDToJoinedUsers {
-		if _, ok := users[userID]; !ok {
+		if ok := users.isIn(userID); !ok {
 			continue
 		}
 		for _, userID := range n._joinedUsers(roomID) {

@@ -282,8 +282,11 @@ func (n *Notifier) IsSharedUser(userA, userB string) bool {
 	defer n.lock.RUnlock()
 	var okA, okB bool
 	for _, users := range n.roomIDToJoinedUsers {
-		_, okA = users[userA]
-		_, okB = users[userB]
+		okA = users.isIn(userA)
+		if !okA {
+			continue
+		}
+		okB = users.isIn(userB)
 		if okA && okB {
 			return true
 		}

@@ -345,11 +348,12 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
 	// This is just the bulk form of addJoinedUser
 	for roomID, userIDs := range roomIDToUserIDs {
 		if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
-			n.roomIDToJoinedUsers[roomID] = make(userIDSet, len(userIDs))
+			n.roomIDToJoinedUsers[roomID] = newUserIDSet(len(userIDs))
 		}
 		for _, userID := range userIDs {
 			n.roomIDToJoinedUsers[roomID].add(userID)
 		}
+		n.roomIDToJoinedUsers[roomID].precompute()
 	}
 }

@@ -440,16 +444,18 @@ func (n *Notifier) _fetchUserStreams(userID string) []*UserDeviceStream {

 func (n *Notifier) _addJoinedUser(roomID, userID string) {
 	if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
-		n.roomIDToJoinedUsers[roomID] = make(userIDSet)
+		n.roomIDToJoinedUsers[roomID] = newUserIDSet(8)
 	}
 	n.roomIDToJoinedUsers[roomID].add(userID)
+	n.roomIDToJoinedUsers[roomID].precompute()
 }

 func (n *Notifier) _removeJoinedUser(roomID, userID string) {
 	if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
-		n.roomIDToJoinedUsers[roomID] = make(userIDSet)
+		n.roomIDToJoinedUsers[roomID] = newUserIDSet(8)
 	}
 	n.roomIDToJoinedUsers[roomID].remove(userID)
+	n.roomIDToJoinedUsers[roomID].precompute()
 }

 func (n *Notifier) JoinedUsers(roomID string) (userIDs []string) {

@@ -521,19 +527,52 @@ func (n *Notifier) _removeEmptyUserStreams() {
 }

 // A string set, mainly existing for improving clarity of structs in this file.
-type userIDSet map[string]struct{}
-
-func (s userIDSet) add(str string) {
-	s[str] = struct{}{}
+type userIDSet struct {
+	sync.Mutex
+	set         map[string]struct{}
+	precomputed []string
 }

-func (s userIDSet) remove(str string) {
-	delete(s, str)
+func newUserIDSet(cap int) *userIDSet {
+	return &userIDSet{
+		set:         make(map[string]struct{}, cap),
+		precomputed: nil,
+	}
 }

-func (s userIDSet) values() (vals []string) {
-	vals = make([]string, 0, len(s))
-	for str := range s {
+func (s *userIDSet) add(str string) {
+	s.Lock()
+	defer s.Unlock()
+	s.set[str] = struct{}{}
+	s.precomputed = s.precomputed[:0] // invalidate cache
+}
+
+func (s *userIDSet) remove(str string) {
+	s.Lock()
+	defer s.Unlock()
+	delete(s.set, str)
+	s.precomputed = s.precomputed[:0] // invalidate cache
+}
+
+func (s *userIDSet) precompute() {
+	s.Lock()
+	defer s.Unlock()
+	s.precomputed = s.values()
+}
+
+func (s *userIDSet) isIn(str string) bool {
+	s.Lock()
+	defer s.Unlock()
+	_, ok := s.set[str]
+	return ok
+}
+
+func (s *userIDSet) values() (vals []string) {
+	if len(s.precomputed) > 0 {
+		return s.precomputed // only return if not invalidated
+	}
+	vals = make([]string, 0, len(s.set))
+	for str := range s.set {
 		vals = append(vals, str)
 	}
 	return
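The rework trades the bare map for a mutex-guarded set with a cached value slice, so hot readers avoid rebuilding a slice on every call. A sketch of the intended cadence, using the unexported methods defined above (illustrative only; these identifiers live inside the syncapi package):

s := newUserIDSet(8)
s.add("@alice:localhost")
s.add("@bob:localhost")
s.precompute()             // cache the value slice once after a batch of writes
users := s.values()        // served from the cache, no per-call allocation
_ = users
s.remove("@bob:localhost") // writes invalidate the cache; precompute() again before reading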
@@ -60,7 +60,9 @@ func Context(
 			Headers: nil,
 		}
 	}
-	filter.Rooms = append(filter.Rooms, roomID)
+	if filter.Rooms != nil {
+		*filter.Rooms = append(*filter.Rooms, roomID)
+	}

 	ctx := req.Context()
 	membershipRes := roomserver.QueryMembershipForUserResponse{}
@@ -262,12 +262,8 @@ func (r *messagesReq) retrieveEvents() (
 	clientEvents []gomatrixserverlib.ClientEvent, start,
 	end types.TopologyToken, err error,
 ) {
-	eventFilter := r.filter
-
 	// Retrieve the events from the local database.
-	streamEvents, err := r.db.GetEventsInTopologicalRange(
-		r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
-	)
+	streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
 	if err != nil {
 		err = fmt.Errorf("GetEventsInRange: %w", err)
 		return
@@ -104,8 +104,8 @@ type Database interface {
 	// DeletePeek deletes all peeks for a given room by a given user
 	// Returns an error if there was a problem communicating with the database.
 	DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
-	// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
-	GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
+	// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
+	GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
 	// EventPositionInTopology returns the depth and stream position of the given event.
 	EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
 	// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
@ -47,14 +47,10 @@ const selectBackwardExtremitiesForRoomSQL = "" +
|
||||||
const deleteBackwardExtremitySQL = "" +
|
const deleteBackwardExtremitySQL = "" +
|
||||||
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
|
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
|
||||||
|
|
||||||
const deleteBackwardExtremitiesForRoomSQL = "" +
|
|
||||||
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
|
|
||||||
|
|
||||||
type backwardExtremitiesStatements struct {
|
type backwardExtremitiesStatements struct {
|
||||||
insertBackwardExtremityStmt *sql.Stmt
|
insertBackwardExtremityStmt *sql.Stmt
|
||||||
selectBackwardExtremitiesForRoomStmt *sql.Stmt
|
selectBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||||
deleteBackwardExtremityStmt *sql.Stmt
|
deleteBackwardExtremityStmt *sql.Stmt
|
||||||
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
|
func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
|
||||||
|
|
@ -72,9 +68,6 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti
|
||||||
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
|
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -113,10 +106,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
||||||
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
|
|
||||||
ctx context.Context, txn *sql.Tx, roomID string,
|
|
||||||
) (err error) {
|
|
||||||
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@@ -233,9 +233,10 @@ func (s *currentRoomStateStatements) SelectCurrentState(
 	excludeEventIDs []string,
 ) ([]*gomatrixserverlib.HeaderedEvent, error) {
 	stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt)
+	senders, notSenders := getSendersStateFilterFilter(stateFilter)
 	rows, err := stmt.QueryContext(ctx, roomID,
-		pq.StringArray(stateFilter.Senders),
-		pq.StringArray(stateFilter.NotSenders),
+		pq.StringArray(senders),
+		pq.StringArray(notSenders),
 		pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
 		pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
 		stateFilter.ContainsURL,
@@ -16,21 +16,45 @@ package postgres

 import (
 	"strings"
+
+	"github.com/matrix-org/gomatrixserverlib"
 )

 // filterConvertWildcardToSQL converts wildcards as defined in
 // https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter
 // to SQL wildcards that can be used with LIKE()
-func filterConvertTypeWildcardToSQL(values []string) []string {
+func filterConvertTypeWildcardToSQL(values *[]string) []string {
 	if values == nil {
 		// Return nil instead of []string{} so IS NULL can work correctly when
 		// the return value is passed into SQL queries
 		return nil
 	}

-	ret := make([]string, len(values))
-	for i := range values {
-		ret[i] = strings.Replace(values[i], "*", "%", -1)
+	v := *values
+	ret := make([]string, len(v))
+	for i := range v {
+		ret[i] = strings.Replace(v[i], "*", "%", -1)
 	}
 	return ret
 }
+
+// TODO: Replace when Dendrite uses Go 1.18
+func getSendersRoomEventFilter(filter *gomatrixserverlib.RoomEventFilter) (senders []string, notSenders []string) {
+	if filter.Senders != nil {
+		senders = *filter.Senders
+	}
+	if filter.NotSenders != nil {
+		notSenders = *filter.NotSenders
+	}
+	return senders, notSenders
+}
+
+func getSendersStateFilterFilter(filter *gomatrixserverlib.StateFilter) (senders []string, notSenders []string) {
+	if filter.Senders != nil {
+		senders = *filter.Senders
+	}
+	if filter.NotSenders != nil {
+		notSenders = *filter.NotSenders
+	}
+	return senders, notSenders
+}
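filterConvertTypeWildcardToSQL now takes a pointer so that an absent filter list survives as SQL NULL, while present lists get their Matrix * wildcards mapped to SQL %. A quick sketch of both cases, calling the unexported function above from within the same package (illustrative only):

eventTypes := []string{"m.room.*"}
_ = filterConvertTypeWildcardToSQL(&eventTypes) // []string{"m.room.%"}
_ = filterConvertTypeWildcardToSQL(nil)         // nil, so "$n::text[] IS NULL" matches in SQL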
@ -56,12 +56,6 @@ const upsertMembershipSQL = "" +
|
||||||
" ON CONFLICT ON CONSTRAINT syncapi_memberships_unique" +
|
" ON CONFLICT ON CONSTRAINT syncapi_memberships_unique" +
|
||||||
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
|
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
|
||||||
|
|
||||||
const selectMembershipSQL = "" +
|
|
||||||
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
|
|
||||||
" WHERE room_id = $1 AND user_id = $2 AND membership = ANY($3)" +
|
|
||||||
" ORDER BY stream_pos DESC" +
|
|
||||||
" LIMIT 1"
|
|
||||||
|
|
||||||
const selectMembershipCountSQL = "" +
|
const selectMembershipCountSQL = "" +
|
||||||
"SELECT COUNT(*) FROM (" +
|
"SELECT COUNT(*) FROM (" +
|
||||||
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
|
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
|
||||||
|
|
@ -69,7 +63,6 @@ const selectMembershipCountSQL = "" +
|
||||||
|
|
||||||
type membershipsStatements struct {
|
type membershipsStatements struct {
|
||||||
upsertMembershipStmt *sql.Stmt
|
upsertMembershipStmt *sql.Stmt
|
||||||
selectMembershipStmt *sql.Stmt
|
|
||||||
selectMembershipCountStmt *sql.Stmt
|
selectMembershipCountStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -82,9 +75,6 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
|
||||||
if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil {
|
if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil {
|
if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -111,14 +101,6 @@ func (s *membershipsStatements) UpsertMembership(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipsStatements) SelectMembership(
|
|
||||||
ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
|
|
||||||
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt)
|
|
||||||
err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *membershipsStatements) SelectMembershipCount(
|
func (s *membershipsStatements) SelectMembershipCount(
|
||||||
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
|
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
|
||||||
) (count int, err error) {
|
) (count int, err error) {
|
||||||
|
|
|
||||||
|
|
@ -81,6 +81,15 @@ const insertEventSQL = "" +
|
||||||
const selectEventsSQL = "" +
|
const selectEventsSQL = "" +
|
||||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)"
|
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)"
|
||||||
|
|
||||||
|
const selectEventsWithFilterSQL = "" +
|
||||||
|
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" +
|
||||||
|
" AND ( $2::text[] IS NULL OR sender = ANY($2) )" +
|
||||||
|
" AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" +
|
||||||
|
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
|
||||||
|
" AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" +
|
||||||
|
" AND ( $6::bool IS NULL OR contains_url = $6 )" +
|
||||||
|
" LIMIT $7"
|
||||||
|
|
||||||
const selectRecentEventsSQL = "" +
|
const selectRecentEventsSQL = "" +
|
||||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||||
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
|
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
|
||||||
|
|
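Each "$n::text[] IS NULL OR ..." guard lets one prepared statement serve both filtered and unfiltered lookups: binding a nil array (or nil *bool for contains_url) makes the corresponding clause a no-op. A hedged sketch of executing the statement above, with hypothetical local variables:

// Illustrative only: nil senders/notSenders disable their clauses entirely.
rows, err := stmt.QueryContext(
	ctx,
	pq.StringArray(eventIDs),
	pq.StringArray(senders),    // nil => "$2::text[] IS NULL" is true
	pq.StringArray(notSenders), // nil => "$3::text[] IS NULL" is true
	pq.StringArray(eventTypes),
	pq.StringArray(notTypes),
	filter.ContainsURL, // *bool; nil disables the contains_url clause
	filter.Limit,
)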
@@ -153,6 +162,7 @@ const selectContextAfterEventSQL = "" +
 type outputRoomEventsStatements struct {
 	insertEventStmt               *sql.Stmt
 	selectEventsStmt              *sql.Stmt
+	selectEventsWitFilterStmt     *sql.Stmt
 	selectMaxEventIDStmt          *sql.Stmt
 	selectRecentEventsStmt        *sql.Stmt
 	selectRecentEventsForSyncStmt *sql.Stmt

@@ -174,6 +184,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
 	return s, sqlutil.StatementList{
 		{&s.insertEventStmt, insertEventSQL},
 		{&s.selectEventsStmt, selectEventsSQL},
+		{&s.selectEventsWitFilterStmt, selectEventsWithFilterSQL},
 		{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
 		{&s.selectRecentEventsStmt, selectRecentEventsSQL},
 		{&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL},

@@ -204,11 +215,11 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
 	stateFilter *gomatrixserverlib.StateFilter, roomIDs []string,
 ) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
 	stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)
+	senders, notSenders := getSendersStateFilterFilter(stateFilter)
 	rows, err := stmt.QueryContext(
 		ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
-		pq.StringArray(stateFilter.Senders),
-		pq.StringArray(stateFilter.NotSenders),
+		pq.StringArray(senders),
+		pq.StringArray(notSenders),
 		pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
 		pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
 		stateFilter.ContainsURL,
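The hunk above calls getSendersStateFilterFilter, and later hunks call getSendersRoomEventFilter; neither body is part of this diff. Given that the filter's Senders/NotSenders fields become pointers in this change, the helpers presumably just unwrap them, so an unset filter turns into a nil slice (and therefore a NULL array parameter). A sketch under that assumption:

    // Hypothetical shape of the helper; the real body is not shown in this diff.
    func getSendersStateFilterFilter(filter *gomatrixserverlib.StateFilter) (senders, notSenders []string) {
        if filter.Senders != nil {
            senders = *filter.Senders // filter present: restrict to these senders
        }
        if filter.NotSenders != nil {
            notSenders = *filter.NotSenders // filter present: exclude these senders
        }
        return senders, notSenders // nil slices mean "no filter"
    }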
@@ -310,7 +321,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
 	// Parse content as JSON and search for an "url" key
 	containsURL := false
 	var content map[string]interface{}
-	if json.Unmarshal(event.Content(), &content) != nil {
+	if json.Unmarshal(event.Content(), &content) == nil {
 		// Set containsURL to true if url is present
 		_, containsURL = content["url"]
 	}
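The one-character change above fixes an inverted error check: with `!= nil`, the `url` key was only inspected when unmarshalling had failed (and `content` was empty), so `contains_url` could never be set. A small standalone illustration of the corrected logic:

    package sketch

    import "encoding/json"

    // hasURL reports whether event content has a top-level "url" key. As in
    // the corrected InsertEvent, the key is only checked when the JSON
    // actually parsed.
    func hasURL(raw []byte) bool {
        var content map[string]interface{}
        if json.Unmarshal(raw, &content) != nil {
            return false // invalid JSON: nothing to inspect
        }
        _, ok := content["url"]
        return ok
    }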
@@ -353,10 +364,11 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
 	} else {
 		stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
 	}
+	senders, notSenders := getSendersRoomEventFilter(eventFilter)
 	rows, err := stmt.QueryContext(
 		ctx, roomID, r.Low(), r.High(),
-		pq.StringArray(eventFilter.Senders),
-		pq.StringArray(eventFilter.NotSenders),
+		pq.StringArray(senders),
+		pq.StringArray(notSenders),
 		pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
 		pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
 		eventFilter.Limit+1,

@@ -398,11 +410,12 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
 	ctx context.Context, txn *sql.Tx,
 	roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
 ) ([]types.StreamEvent, error) {
+	senders, notSenders := getSendersRoomEventFilter(eventFilter)
 	stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt)
 	rows, err := stmt.QueryContext(
 		ctx, roomID, r.Low(), r.High(),
-		pq.StringArray(eventFilter.Senders),
-		pq.StringArray(eventFilter.NotSenders),
+		pq.StringArray(senders),
+		pq.StringArray(notSenders),
 		pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
 		pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
 		eventFilter.Limit,
@@ -427,15 +440,52 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
 // selectEvents returns the events for the given event IDs. If an event is
 // missing from the database, it will be omitted.
 func (s *outputRoomEventsStatements) SelectEvents(
-	ctx context.Context, txn *sql.Tx, eventIDs []string,
+	ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
 ) ([]types.StreamEvent, error) {
-	stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
-	rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
+	var (
+		stmt *sql.Stmt
+		rows *sql.Rows
+		err  error
+	)
+	if filter == nil {
+		stmt = sqlutil.TxStmt(txn, s.selectEventsStmt)
+		rows, err = stmt.QueryContext(ctx, pq.StringArray(eventIDs))
+	} else {
+		senders, notSenders := getSendersRoomEventFilter(filter)
+		stmt = sqlutil.TxStmt(txn, s.selectEventsWitFilterStmt)
+		rows, err = stmt.QueryContext(ctx,
+			pq.StringArray(eventIDs),
+			pq.StringArray(senders),
+			pq.StringArray(notSenders),
+			pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
+			pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
+			filter.ContainsURL,
+			filter.Limit,
+		)
+	}
 	if err != nil {
 		return nil, err
 	}
 	defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
-	return rowsToStreamEvents(rows)
+	streamEvents, err := rowsToStreamEvents(rows)
+	if err != nil {
+		return nil, err
+	}
+	if preserveOrder {
+		eventMap := make(map[string]types.StreamEvent)
+		for _, ev := range streamEvents {
+			eventMap[ev.EventID()] = ev
+		}
+		var returnEvents []types.StreamEvent
+		for _, eventID := range eventIDs {
+			ev, ok := eventMap[eventID]
+			if ok {
+				returnEvents = append(returnEvents, ev)
+			}
+		}
+		return returnEvents, nil
+	}
+	return streamEvents, nil
 }
 
 func (s *outputRoomEventsStatements) DeleteEventsForRoom(
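Because `event_id = ANY($1)` carries no ordering guarantee, callers that need results back in request order (such as the topological-range path later in this diff) pass preserveOrder, and the rows are re-sequenced through a map exactly as the hunk above does. The same idea in isolation:

    // Sketch: restore the caller's ordering after an unordered ANY($1) select.
    // Events missing from the result set are simply skipped.
    eventMap := make(map[string]types.StreamEvent, len(streamEvents))
    for _, ev := range streamEvents {
        eventMap[ev.EventID()] = ev
    }
    ordered := make([]types.StreamEvent, 0, len(eventIDs))
    for _, id := range eventIDs {
        if ev, ok := eventMap[id]; ok {
            ordered = append(ordered, ev)
        }
    }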
@@ -462,10 +512,11 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn
 func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
 	ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
 ) (evts []*gomatrixserverlib.HeaderedEvent, err error) {
+	senders, notSenders := getSendersRoomEventFilter(filter)
 	rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext(
 		ctx, roomID, id, filter.Limit,
-		pq.StringArray(filter.Senders),
-		pq.StringArray(filter.NotSenders),
+		pq.StringArray(senders),
+		pq.StringArray(notSenders),
 		pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
 		pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
 	)

@@ -494,10 +545,11 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
 func (s *outputRoomEventsStatements) SelectContextAfterEvent(
 	ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
 ) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) {
+	senders, notSenders := getSendersRoomEventFilter(filter)
 	rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext(
 		ctx, roomID, id, filter.Limit,
-		pq.StringArray(filter.Senders),
-		pq.StringArray(filter.NotSenders),
+		pq.StringArray(senders),
+		pq.StringArray(notSenders),
 		pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
 		pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
 	)
@@ -73,9 +73,6 @@ const selectMaxPositionInTopologySQL = "" +
 	"SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" +
 	") ORDER BY stream_position DESC LIMIT 1"
 
-const deleteTopologyForRoomSQL = "" +
-	"DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
-
 const selectStreamToTopologicalPositionAscSQL = "" +
 	"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;"
 

@@ -88,7 +85,6 @@ type outputRoomEventsTopologyStatements struct {
 	selectEventIDsInRangeDESCStmt             *sql.Stmt
 	selectPositionInTopologyStmt              *sql.Stmt
 	selectMaxPositionInTopologyStmt           *sql.Stmt
-	deleteTopologyForRoomStmt                 *sql.Stmt
 	selectStreamToTopologicalPositionAscStmt  *sql.Stmt
 	selectStreamToTopologicalPositionDescStmt *sql.Stmt
 }

@@ -114,9 +110,6 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) {
 	if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
 		return nil, err
 	}
-	if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil {
-		return nil, err
-	}
 	if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil {
 		return nil, err
 	}

@@ -148,9 +141,9 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
 	// is requested or not.
 	var stmt *sql.Stmt
 	if chronologicalOrder {
-		stmt = s.selectEventIDsInRangeASCStmt
+		stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt)
 	} else {
-		stmt = s.selectEventIDsInRangeDESCStmt
+		stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt)
 	}
 
 	// Query the event IDs.

@@ -203,10 +196,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
 	err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
 	return
 }
-
-func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
-	ctx context.Context, txn *sql.Tx, roomID string,
-) (err error) {
-	_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
-	return err
-}
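The topology hunks above also start routing prepared statements through sqlutil.TxStmt so that reads participate in the caller's transaction. Assuming the helper follows the usual database/sql pattern, it amounts to:

    // Sketch of the assumed sqlutil.TxStmt behaviour: bind the prepared
    // statement to the transaction when one is supplied, otherwise use the
    // statement as prepared on the pool.
    func TxStmt(txn *sql.Tx, stmt *sql.Stmt) *sql.Stmt {
        if txn != nil {
            stmt = txn.Stmt(stmt)
        }
        return stmt
    }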
@@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre
 // Returns an error if there was a problem talking with the database.
 // Does not include any transaction IDs in the returned events.
 func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
-	streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs)
+	streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false)
 	if err != nil {
 		return nil, err
 	}

@@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
 
 	// Check if we have all of the event's previous events. If an event is
 	// missing, add it to the room's backward extremities.
-	prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs())
+	prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false)
 	if err != nil {
 		return err
 	}
@@ -429,7 +429,8 @@ func (d *Database) updateRoomState(
 func (d *Database) GetEventsInTopologicalRange(
 	ctx context.Context,
 	from, to *types.TopologyToken,
-	roomID string, limit int,
+	roomID string,
+	filter *gomatrixserverlib.RoomEventFilter,
 	backwardOrdering bool,
 ) (events []types.StreamEvent, err error) {
 	var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition

@@ -450,14 +451,14 @@ func (d *Database) GetEventsInTopologicalRange(
 	// Select the event IDs from the defined range.
 	var eIDs []string
 	eIDs, err = d.Topology.SelectEventIDsInRange(
-		ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, limit, !backwardOrdering,
+		ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
 	)
 	if err != nil {
 		return
 	}
 
 	// Retrieve the events' contents using their IDs.
-	events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs)
+	events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, filter, true)
 	return
 }
 
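GetEventsInTopologicalRange now takes the whole RoomEventFilter rather than a bare limit, so the limit and the sender/type/URL filters travel together down to SelectEvents. A hedged call-site sketch (token and variable names are illustrative, not from this diff):

    // Sketch: a /messages-style caller. filter.Limit replaces the old limit
    // argument; the remaining fields are applied when the events are loaded.
    filter := &gomatrixserverlib.RoomEventFilter{Limit: 20}
    events, err := d.GetEventsInTopologicalRange(
        ctx, fromToken, toToken, roomID, filter, backwardOrdering,
    )
    if err != nil {
        return err
    }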
@@ -619,7 +620,7 @@ func (d *Database) fetchMissingStateEvents(
 ) ([]types.StreamEvent, error) {
 	// Fetch from the events table first so we pick up the stream ID for the
 	// event.
-	events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs)
+	events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false)
 	if err != nil {
 		return nil, err
 	}

@@ -687,6 +688,9 @@ func (d *Database) GetStateDeltas(
 	// user has ever interacted with — joined to, kicked/banned from, left.
 	memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID)
 	if err != nil {
+		if err == sql.ErrNoRows {
+			return nil, nil, nil
+		}
 		return nil, nil, err
 	}
 

@@ -704,17 +708,23 @@ func (d *Database) GetStateDeltas(
 	// get all the state events ever (i.e. for all available rooms) between these two positions
 	stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
 	if err != nil {
+		if err == sql.ErrNoRows {
+			return nil, nil, nil
+		}
 		return nil, nil, err
 	}
 	state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
 	if err != nil {
+		if err == sql.ErrNoRows {
+			return nil, nil, nil
+		}
 		return nil, nil, err
 	}
 
 	// find out which rooms this user is peeking, if any.
 	// We do this before joins so any peeks get overwritten
 	peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
-	if err != nil {
+	if err != nil && err != sql.ErrNoRows {
 		return nil, nil, err
 	}
 

@@ -725,6 +735,9 @@ func (d *Database) GetStateDeltas(
 			var s []types.StreamEvent
 			s, err = d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
 			if err != nil {
+				if err == sql.ErrNoRows {
+					continue
+				}
 				return nil, nil, err
 			}
 			state[peek.RoomID] = s

@@ -752,6 +765,9 @@ func (d *Database) GetStateDeltas(
 			var s []types.StreamEvent
 			s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter)
 			if err != nil {
+				if err == sql.ErrNoRows {
+					continue
+				}
 				return nil, nil, err
 			}
 			state[roomID] = s

@@ -802,6 +818,9 @@ func (d *Database) GetStateDeltasForFullStateSync(
 	// user has ever interacted with — joined to, kicked/banned from, left.
 	memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID)
 	if err != nil {
+		if err == sql.ErrNoRows {
+			return nil, nil, nil
+		}
 		return nil, nil, err
 	}
 

@@ -818,7 +837,7 @@ func (d *Database) GetStateDeltasForFullStateSync(
 	deltas := make(map[string]types.StateDelta)
 
 	peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
-	if err != nil {
+	if err != nil && err != sql.ErrNoRows {
 		return nil, nil, err
 	}
 

@@ -827,6 +846,9 @@ func (d *Database) GetStateDeltasForFullStateSync(
 		if !peek.Deleted {
 			s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
 			if stateErr != nil {
+				if stateErr == sql.ErrNoRows {
+					continue
+				}
 				return nil, nil, stateErr
 			}
 			deltas[peek.RoomID] = types.StateDelta{

@@ -840,10 +862,16 @@ func (d *Database) GetStateDeltasForFullStateSync(
 	// Get all the state events ever between these two positions
 	stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
 	if err != nil {
+		if err == sql.ErrNoRows {
+			return nil, nil, nil
+		}
 		return nil, nil, err
 	}
 	state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
 	if err != nil {
+		if err == sql.ErrNoRows {
+			return nil, nil, nil
+		}
 		return nil, nil, err
 	}
 

@@ -868,6 +896,9 @@ func (d *Database) GetStateDeltasForFullStateSync(
 	for _, joinedRoomID := range joinedRoomIDs {
 		s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter)
 		if stateErr != nil {
+			if stateErr == sql.ErrNoRows {
+				continue
+			}
 			return nil, nil, stateErr
 		}
 		deltas[joinedRoomID] = types.StateDelta{
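A pattern repeats through the GetStateDeltas hunks above: sql.ErrNoRows is downgraded to an empty result, or a skipped room, because an account with no relevant rows is a normal sync outcome rather than a failure. The policy in one place, as a sketch:

    // Sketch: treat "no rows" as success with nothing to report.
    func ignoreNoRows(err error) error {
        if err == sql.ErrNoRows {
            return nil
        }
        return err
    }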
@@ -41,23 +41,23 @@ const insertAccountDataSQL = "" +
 	" ON CONFLICT (user_id, room_id, type) DO UPDATE" +
 	" SET id = $5"
 
+// further parameters are added by prepareWithFilters
 const selectAccountDataInRangeSQL = "" +
 	"SELECT room_id, type FROM syncapi_account_data_type" +
-	" WHERE user_id = $1 AND id > $2 AND id <= $3" +
-	" ORDER BY id ASC"
+	" WHERE user_id = $1 AND id > $2 AND id <= $3"
 
 const selectMaxAccountDataIDSQL = "" +
 	"SELECT MAX(id) FROM syncapi_account_data_type"
 
 type accountDataStatements struct {
 	db                           *sql.DB
-	streamIDStatements           *streamIDStatements
+	streamIDStatements           *StreamIDStatements
 	insertAccountDataStmt        *sql.Stmt
 	selectMaxAccountDataIDStmt   *sql.Stmt
 	selectAccountDataInRangeStmt *sql.Stmt
 }
 
-func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
+func NewSqliteAccountDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.AccountData, error) {
 	s := &accountDataStatements{
 		db:                 db,
 		streamIDStatements: streamID,
@@ -94,18 +94,24 @@ func (s *accountDataStatements) SelectAccountDataInRange(
 	ctx context.Context,
 	userID string,
 	r types.Range,
-	accountDataFilterPart *gomatrixserverlib.EventFilter,
+	filter *gomatrixserverlib.EventFilter,
 ) (data map[string][]string, err error) {
 	data = make(map[string][]string)
+	stmt, params, err := prepareWithFilters(
+		s.db, nil, selectAccountDataInRangeSQL,
+		[]interface{}{
+			userID, r.Low(), r.High(),
+		},
+		filter.Senders, filter.NotSenders,
+		filter.Types, filter.NotTypes,
+		[]string{}, nil, filter.Limit, FilterOrderAsc)
 
-	rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High())
+	rows, err := stmt.QueryContext(ctx, params...)
 	if err != nil {
 		return
 	}
 	defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
 
-	var entries int
-
 	for rows.Next() {
 		var dataType string
 		var roomID string

@@ -114,31 +120,11 @@ func (s *accountDataStatements) SelectAccountDataInRange(
 			return
 		}
 
-		// check if we should add this by looking at the filter.
-		// It would be nice if we could do this in SQL-land, but the mix of variadic
-		// and positional parameters makes the query annoyingly hard to do, it's easier
-		// and clearer to do it in Go-land. If there are no filters for [not]types then
-		// this gets skipped.
-		for _, includeType := range accountDataFilterPart.Types {
-			if includeType != dataType { // TODO: wildcard support
-				continue
-			}
-		}
-		for _, excludeType := range accountDataFilterPart.NotTypes {
-			if excludeType == dataType { // TODO: wildcard support
-				continue
-			}
-		}
-
 		if len(data[roomID]) > 0 {
 			data[roomID] = append(data[roomID], dataType)
 		} else {
 			data[roomID] = []string{dataType}
 		}
-		entries++
-		if entries >= accountDataFilterPart.Limit {
-			break
-		}
 	}
 
 	return data, nil
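With the rewrite above, the Go-side type filtering and manual limit counting disappear: prepareWithFilters appends the filter clauses and the limit to selectAccountDataInRangeSQL at prepare time. An illustrative call, assuming the pointer-typed filter fields used elsewhere in this diff:

    // Sketch: keep only one account-data type, at most five rows.
    types := []string{"m.fully_read"}
    filter := gomatrixserverlib.EventFilter{
        Limit: 5,
        Types: &types, // field assumed to be *[]string after this change
    }
    data, err := s.SelectAccountDataInRange(ctx, userID, r, &filter)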
@@ -47,15 +47,11 @@ const selectBackwardExtremitiesForRoomSQL = "" +
 const deleteBackwardExtremitySQL = "" +
 	"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
 
-const deleteBackwardExtremitiesForRoomSQL = "" +
-	"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
-
 type backwardExtremitiesStatements struct {
 	db                                   *sql.DB
 	insertBackwardExtremityStmt          *sql.Stmt
 	selectBackwardExtremitiesForRoomStmt *sql.Stmt
 	deleteBackwardExtremityStmt          *sql.Stmt
-	deleteBackwardExtremitiesForRoomStmt *sql.Stmt
 }
 
 func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {

@@ -75,9 +71,6 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
 	if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
 		return nil, err
 	}
-	if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
-		return nil, err
-	}
 	return s, nil
 }
 

@@ -116,10 +109,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
 	_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
 	return err
 }
-
-func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
-	ctx context.Context, txn *sql.Tx, roomID string,
-) (err error) {
-	_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
-	return err
-}
@@ -90,7 +90,7 @@ const selectEventsWithEventIDsSQL = "" +
 
 type currentRoomStateStatements struct {
 	db                           *sql.DB
-	streamIDStatements           *streamIDStatements
+	streamIDStatements           *StreamIDStatements
 	upsertRoomStateStmt          *sql.Stmt
 	deleteRoomStateByEventIDStmt *sql.Stmt
 	deleteRoomStateForRoomStmt   *sql.Stmt

@@ -100,7 +100,7 @@ type currentRoomStateStatements struct {
 	selectStateEventStmt         *sql.Stmt
 }
 
-func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
+func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
 	s := &currentRoomStateStatements{
 		db:                 db,
 		streamIDStatements: streamID,

@@ -220,7 +220,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
 		},
 		stateFilter.Senders, stateFilter.NotSenders,
 		stateFilter.Types, stateFilter.NotTypes,
-		excludeEventIDs, stateFilter.Limit, FilterOrderNone,
+		excludeEventIDs, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderNone,
 	)
 	if err != nil {
 		return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
@@ -25,34 +25,53 @@ const (
 // parts.
 func prepareWithFilters(
 	db *sql.DB, txn *sql.Tx, query string, params []interface{},
-	senders, notsenders, types, nottypes []string, excludeEventIDs []string,
-	limit int, order FilterOrder,
+	senders, notsenders, types, nottypes *[]string, excludeEventIDs []string,
+	containsURL *bool, limit int, order FilterOrder,
 ) (*sql.Stmt, []interface{}, error) {
 	offset := len(params)
-	if count := len(senders); count > 0 {
-		query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
-		for _, v := range senders {
-			params, offset = append(params, v), offset+1
+	if senders != nil {
+		if count := len(*senders); count > 0 {
+			query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
+			for _, v := range *senders {
+				params, offset = append(params, v), offset+1
+			}
+		} else {
+			query += ` AND sender = ""`
 		}
 	}
-	if count := len(notsenders); count > 0 {
-		query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
-		for _, v := range notsenders {
-			params, offset = append(params, v), offset+1
+	if notsenders != nil {
+		if count := len(*notsenders); count > 0 {
+			query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
+			for _, v := range *notsenders {
+				params, offset = append(params, v), offset+1
+			}
+		} else {
+			query += ` AND sender NOT = ""`
 		}
 	}
-	if count := len(types); count > 0 {
-		query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
-		for _, v := range types {
-			params, offset = append(params, v), offset+1
+	if types != nil {
+		if count := len(*types); count > 0 {
+			query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
+			for _, v := range *types {
+				params, offset = append(params, v), offset+1
+			}
+		} else {
+			query += ` AND type = ""`
 		}
 	}
-	if count := len(nottypes); count > 0 {
-		query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
-		for _, v := range nottypes {
-			params, offset = append(params, v), offset+1
+	if nottypes != nil {
+		if count := len(*nottypes); count > 0 {
+			query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
+			for _, v := range *nottypes {
+				params, offset = append(params, v), offset+1
+			}
+		} else {
+			query += ` AND type NOT = ""`
 		}
 	}
+	if containsURL != nil {
+		query += fmt.Sprintf(" AND contains_url = %v", *containsURL)
+	}
 	if count := len(excludeEventIDs); count > 0 {
 		query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
 		for _, v := range excludeEventIDs {
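The move from `[]string` to `*[]string` gives each list three distinguishable states, which the empty-slice branches above encode: a nil pointer appends no clause at all, an empty-but-present list appends a clause that can never match, and a populated list appends an IN clause. A small helper a caller might use to keep that distinction explicit (hypothetical, not from this diff):

    // senderFilter builds the pointer-typed list for prepareWithFilters.
    func senderFilter(senders []string, specified bool) *[]string {
        if !specified {
            return nil // no clause appended; all senders match
        }
        // If specified but empty, prepareWithFilters emits `AND sender = ""`,
        // which matches nothing by design.
        return &senders
    }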
@@ -59,14 +59,14 @@ const selectMaxInviteIDSQL = "" +
 
 type inviteEventsStatements struct {
 	db                            *sql.DB
-	streamIDStatements            *streamIDStatements
+	streamIDStatements            *StreamIDStatements
 	insertInviteEventStmt         *sql.Stmt
 	selectInviteEventsInRangeStmt *sql.Stmt
 	deleteInviteEventStmt         *sql.Stmt
 	selectMaxInviteIDStmt         *sql.Stmt
 }
 
-func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
+func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) {
 	s := &inviteEventsStatements{
 		db:                 db,
 		streamIDStatements: streamID,
@@ -18,7 +18,6 @@ import (
 	"context"
 	"database/sql"
 	"fmt"
-	"strings"
 
 	"github.com/matrix-org/dendrite/internal/sqlutil"
 	"github.com/matrix-org/dendrite/syncapi/storage/tables"

@@ -57,12 +56,6 @@ const upsertMembershipSQL = "" +
 	" ON CONFLICT (room_id, user_id, membership)" +
 	" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
 
-const selectMembershipSQL = "" +
-	"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
-	" WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" +
-	" ORDER BY stream_pos DESC" +
-	" LIMIT 1"
-
 const selectMembershipCountSQL = "" +
 	"SELECT COUNT(*) FROM (" +
 	" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +

@@ -111,22 +104,6 @@ func (s *membershipsStatements) UpsertMembership(
 	return err
 }
 
-func (s *membershipsStatements) SelectMembership(
-	ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
-) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
-	params := []interface{}{roomID, userID}
-	for _, membership := range memberships {
-		params = append(params, membership)
-	}
-	orig := strings.Replace(selectMembershipSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
-	stmt, err := s.db.Prepare(orig)
-	if err != nil {
-		return "", 0, 0, err
-	}
-	err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos)
-	return
-}
-
 func (s *membershipsStatements) SelectMembershipCount(
 	ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
 ) (count int, err error) {
@@ -58,7 +58,7 @@ const insertEventSQL = "" +
 	"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
 
 const selectEventsSQL = "" +
-	"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
+	"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)"
 
 const selectRecentEventsSQL = "" +
 	"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +

@@ -111,9 +111,8 @@ const selectContextAfterEventSQL = "" +
 
 type outputRoomEventsStatements struct {
 	db                            *sql.DB
-	streamIDStatements            *streamIDStatements
+	streamIDStatements            *StreamIDStatements
 	insertEventStmt               *sql.Stmt
-	selectEventsStmt              *sql.Stmt
 	selectMaxEventIDStmt          *sql.Stmt
 	updateEventJSONStmt           *sql.Stmt
 	deleteEventsForRoomStmt       *sql.Stmt

@@ -122,7 +121,7 @@ type outputRoomEventsStatements struct {
 	selectContextAfterEventStmt   *sql.Stmt
 }
 
-func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
+func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) {
 	s := &outputRoomEventsStatements{
 		db:                 db,
 		streamIDStatements: streamID,

@@ -133,7 +132,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
 	}
 	return s, sqlutil.StatementList{
 		{&s.insertEventStmt, insertEventSQL},
-		{&s.selectEventsStmt, selectEventsSQL},
 		{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
 		{&s.updateEventJSONStmt, updateEventJSONSQL},
 		{&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL},

@@ -170,7 +168,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
 		s.db, txn, stmtSQL, inputParams,
 		stateFilter.Senders, stateFilter.NotSenders,
 		stateFilter.Types, stateFilter.NotTypes,
-		nil, stateFilter.Limit, FilterOrderAsc,
+		nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc,
 	)
 	if err != nil {
 		return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)

@@ -279,7 +277,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
 	// Parse content as JSON and search for an "url" key
 	containsURL := false
 	var content map[string]interface{}
-	if json.Unmarshal(event.Content(), &content) != nil {
+	if json.Unmarshal(event.Content(), &content) == nil {
 		// Set containsURL to true if url is present
 		_, containsURL = content["url"]
 	}

@@ -347,7 +345,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
 		},
 		eventFilter.Senders, eventFilter.NotSenders,
 		eventFilter.Types, eventFilter.NotTypes,
-		nil, eventFilter.Limit+1, FilterOrderDesc,
+		nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc,
 	)
 	if err != nil {
 		return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)

@@ -395,7 +393,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
 		},
 		eventFilter.Senders, eventFilter.NotSenders,
 		eventFilter.Types, eventFilter.NotTypes,
-		nil, eventFilter.Limit, FilterOrderAsc,
+		nil, eventFilter.ContainsURL, eventFilter.Limit, FilterOrderAsc,
 	)
 	if err != nil {
 		return nil, fmt.Errorf("s.prepareWithFilters: %w", err)

@@ -421,21 +419,50 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
 // selectEvents returns the events for the given event IDs. If an event is
 // missing from the database, it will be omitted.
 func (s *outputRoomEventsStatements) SelectEvents(
-	ctx context.Context, txn *sql.Tx, eventIDs []string,
+	ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
 ) ([]types.StreamEvent, error) {
-	var returnEvents []types.StreamEvent
-	stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
-	for _, eventID := range eventIDs {
-		rows, err := stmt.QueryContext(ctx, eventID)
-		if err != nil {
-			return nil, err
-		}
-		if streamEvents, err := rowsToStreamEvents(rows); err == nil {
-			returnEvents = append(returnEvents, streamEvents...)
-		}
-		internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
+	iEventIDs := make([]interface{}, len(eventIDs))
+	for i := range eventIDs {
+		iEventIDs[i] = eventIDs[i]
 	}
-	return returnEvents, nil
+	selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
+
+	if filter == nil {
+		filter = &gomatrixserverlib.RoomEventFilter{Limit: 20}
+	}
+	stmt, params, err := prepareWithFilters(
+		s.db, txn, selectSQL, iEventIDs,
+		filter.Senders, filter.NotSenders,
+		filter.Types, filter.NotTypes,
+		nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
+	)
+	if err != nil {
+		return nil, err
+	}
+	rows, err := stmt.QueryContext(ctx, params...)
+	if err != nil {
+		return nil, err
+	}
+	defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
+	streamEvents, err := rowsToStreamEvents(rows)
+	if err != nil {
+		return nil, err
+	}
+	if preserveOrder {
+		var returnEvents []types.StreamEvent
+		eventMap := make(map[string]types.StreamEvent)
+		for _, ev := range streamEvents {
+			eventMap[ev.EventID()] = ev
+		}
+		for _, eventID := range eventIDs {
+			ev, ok := eventMap[eventID]
+			if ok {
+				returnEvents = append(returnEvents, ev)
+			}
+		}
+		return returnEvents, nil
+	}
+	return streamEvents, nil
 }
 
 func (s *outputRoomEventsStatements) DeleteEventsForRoom(

@@ -507,7 +534,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
 		},
 		filter.Senders, filter.NotSenders,
 		filter.Types, filter.NotTypes,
-		nil, filter.Limit, FilterOrderDesc,
+		nil, filter.ContainsURL, filter.Limit, FilterOrderDesc,
 	)
 
 	rows, err := stmt.QueryContext(ctx, params...)

@@ -543,7 +570,7 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent(
 		},
 		filter.Senders, filter.NotSenders,
 		filter.Types, filter.NotTypes,
-		nil, filter.Limit, FilterOrderAsc,
+		nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
 	)
 
 	rows, err := stmt.QueryContext(ctx, params...)
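SQLite has no array parameters, so the rewritten SelectEvents splices one placeholder per event ID into the base statement before handing it to prepareWithFilters. Assuming sqlutil.QueryVariadic(n) returns a parenthesised list of n placeholders, the expansion looks like:

    // Sketch: expanding "IN ($1)" for three event IDs.
    eventIDs := []string{"$a:x", "$b:x", "$c:x"}
    selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
    // selectSQL now ends in: ...WHERE event_id IN ($1, $2, $3), and
    // prepareWithFilters appends its own clauses after that, numbering
    // further parameters from the supplied offset.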
@@ -78,7 +78,6 @@ type outputRoomEventsTopologyStatements struct {
 	selectEventIDsInRangeDESCStmt             *sql.Stmt
 	selectPositionInTopologyStmt              *sql.Stmt
 	selectMaxPositionInTopologyStmt           *sql.Stmt
-	deleteTopologyForRoomStmt                 *sql.Stmt
 	selectStreamToTopologicalPositionAscStmt  *sql.Stmt
 	selectStreamToTopologicalPositionDescStmt *sql.Stmt
 }

@@ -191,10 +190,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
 	err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
 	return
 }
-
-func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
-	ctx context.Context, txn *sql.Tx, roomID string,
-) (err error) {
-	_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
-	return err
-}
@@ -66,7 +66,7 @@ const selectMaxPeekIDSQL = "" +
 
 type peekStatements struct {
 	db                 *sql.DB
-	streamIDStatements *streamIDStatements
+	streamIDStatements *StreamIDStatements
 	insertPeekStmt     *sql.Stmt
 	deletePeekStmt     *sql.Stmt
 	deletePeeksStmt    *sql.Stmt

@@ -75,7 +75,7 @@ type peekStatements struct {
 	selectMaxPeekIDStmt *sql.Stmt
 }
 
-func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) {
+func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) {
 	_, err := db.Exec(peeksSchema)
 	if err != nil {
 		return nil, err
@@ -75,7 +75,7 @@ const selectPresenceAfter = "" +
 
 type presenceStatements struct {
 	db                         *sql.DB
-	streamIDStatements         *streamIDStatements
+	streamIDStatements         *StreamIDStatements
 	upsertPresenceStmt         *sql.Stmt
 	upsertPresenceFromSyncStmt *sql.Stmt
 	selectPresenceForUsersStmt *sql.Stmt

@@ -83,7 +83,7 @@ type presenceStatements struct {
 	selectPresenceAfterStmt    *sql.Stmt
 }
 
-func NewSqlitePresenceTable(db *sql.DB, streamID *streamIDStatements) (*presenceStatements, error) {
+func NewSqlitePresenceTable(db *sql.DB, streamID *StreamIDStatements) (*presenceStatements, error) {
 	_, err := db.Exec(presenceSchema)
 	if err != nil {
 		return nil, err
@@ -59,13 +59,13 @@ const selectMaxReceiptIDSQL = "" +
 
 type receiptStatements struct {
 	db                 *sql.DB
-	streamIDStatements *streamIDStatements
+	streamIDStatements *StreamIDStatements
 	upsertReceipt      *sql.Stmt
 	selectRoomReceipts *sql.Stmt
 	selectMaxReceiptID *sql.Stmt
 }
 
-func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) {
+func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) {
 	_, err := db.Exec(receiptsSchema)
 	if err != nil {
 		return nil, err
@@ -32,12 +32,12 @@ const increaseStreamIDStmt = "" +
 	"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" +
 	" RETURNING stream_id"
 
-type streamIDStatements struct {
+type StreamIDStatements struct {
 	db                   *sql.DB
 	increaseStreamIDStmt *sql.Stmt
 }
 
-func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
+func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) {
 	s.db = db
 	_, err = db.Exec(streamIDTableSchema)
 	if err != nil {

@@ -49,31 +49,31 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
 	return
 }
 
-func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
+func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
 	increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
 	err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos)
 	return
 }
 
-func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
+func (s *StreamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
 	increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
 	err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
 	return
 }
 
-func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
+func (s *StreamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
 	increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
 	err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos)
 	return
 }
 
-func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
+func (s *StreamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
 	increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
 	err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
 	return
 }
 
-func (s *streamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
+func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
 	increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
 	err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos)
 	return
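Exporting streamIDStatements as StreamIDStatements, and prepare as Prepare, lets code outside the sqlite3 package allocate the stream-ID sequencer and hand it to the table constructors. A hedged sketch of such a call site (not part of this diff):

    // Sketch: constructing sqlite3 sync tables from another package.
    var streamID sqlite3.StreamIDStatements
    if err := streamID.Prepare(db); err != nil {
        return err
    }
    accountData, err := sqlite3.NewSqliteAccountDataTable(db, &streamID)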
@@ -30,7 +30,7 @@ type SyncServerDatasource struct {
 	shared.Database
 	db       *sql.DB
 	writer   sqlutil.Writer
-	streamID streamIDStatements
+	streamID StreamIDStatements
 }
 
 // NewDatabase creates a new sync server database

@@ -49,7 +49,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
 }
 
 func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
-	if err = d.streamID.prepare(d.db); err != nil {
+	if err = d.streamID.Prepare(d.db); err != nil {
 		return err
 	}
 	accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)
@@ -3,6 +3,7 @@ package storage_test
 import (
 	"context"
 	"fmt"
+	"reflect"
 	"testing"

 	"github.com/matrix-org/dendrite/setup/config"
@@ -38,7 +39,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
 		if err != nil {
 			t.Fatalf("WriteEvent failed: %s", err)
 		}
-		fmt.Println("Event ID", ev.EventID(), " spos=", pos, "depth=", ev.Depth())
+		t.Logf("Event ID %s spos=%v depth=%v", ev.EventID(), pos, ev.Depth())
 		positions = append(positions, pos)
 	}
 	return
@@ -46,7 +47,6 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver

 func TestWriteEvents(t *testing.T) {
 	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
-		t.Parallel()
 		alice := test.NewUser()
 		r := test.NewRoom(t, alice)
 		db, close := MustCreateDatabase(t, dbType)
@@ -61,84 +61,84 @@ func TestRecentEventsPDU(t *testing.T) {
 		db, close := MustCreateDatabase(t, dbType)
 		defer close()
 		alice := test.NewUser()
-		var filter gomatrixserverlib.RoomEventFilter
-		filter.Limit = 100
+		// dummy room to make sure SQL queries are filtering on room ID
+		MustWriteEvents(t, db, test.NewRoom(t, alice).Events())

+		// actual test room
 		r := test.NewRoom(t, alice)
 		r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"})
 		events := r.Events()
 		positions := MustWriteEvents(t, db, events)

+		// dummy room to make sure SQL queries are filtering on room ID
+		MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
+
 		latest, err := db.MaxStreamPositionForPDUs(ctx)
 		if err != nil {
 			t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
 		}

 		testCases := []struct {
 			Name string
 			From types.StreamPosition
 			To   types.StreamPosition
-			WantEvents  []*gomatrixserverlib.HeaderedEvent
-			WantLimited bool
+			Limit        int
+			ReverseOrder bool
+			WantEvents   []*gomatrixserverlib.HeaderedEvent
+			WantLimited  bool
 		}{
 			// The purpose of this test is to make sure that incremental syncs are including up to the latest events.
-			// It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event.
+			// It's a basic sanity test that sync works. It creates a streaming position that is on the penultimate event.
 			// It makes sure the response includes the final event.
 			{
-				Name: "IncrementalSync penultimate",
+				Name: "penultimate",
 				From: positions[len(positions)-2], // pretend we are at the penultimate event
 				To:   latest,
+				Limit: 100,
 				WantEvents:  events[len(events)-1:],
 				WantLimited: false,
 			},
-			/*
-				// The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the
-				// number of returned events. This is critical for big rooms hence the test here.
-				{
-					Name: "IncrementalSync limited",
-					DoSync: func() (*types.Response, error) {
-						from := types.StreamingToken{ // pretend we are 10 events behind
-							PDUPosition: positions[len(positions)-11],
-						}
-						res := types.NewResponse()
-						// limit is set to 5
-						return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
-					},
-					// want the last 5 events, NOT the last 10.
-					WantTimeline: events[len(events)-5:],
-				},
-				// The purpose of this test is to check that CompleteSync returns all the current state as well as
-				// honouring the `numRecentEventsPerRoom` value
-				{
-					Name: "CompleteSync limited",
-					DoSync: func() (*types.Response, error) {
-						res := types.NewResponse()
-						// limit set to 5
-						return db.CompleteSync(ctx, res, testUserDeviceA, 5)
-					},
-					// want the last 5 events
-					WantTimeline: events[len(events)-5:],
-					// want all state for the room
-					WantState: state,
-				},
-				// The purpose of this test is to check that CompleteSync can return everything with a high enough
-				// `numRecentEventsPerRoom`.
-				{
-					Name: "CompleteSync",
-					DoSync: func() (*types.Response, error) {
-						res := types.NewResponse()
-						return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
-					},
-					WantTimeline: events,
-					// We want no state at all as that field in /sync is the delta between the token (beginning of time)
-					// and the START of the timeline.
-				}, */
+			// The purpose of this test is to check that limits can be applied and work.
+			// This is critical for big rooms hence the test here.
+			{
+				Name:        "limited",
+				From:        0,
+				To:          latest,
+				Limit:       1,
+				WantEvents:  events[len(events)-1:],
+				WantLimited: true,
+			},
+			// The purpose of this test is to check that we can return every event with a high
+			// enough limit
+			{
+				Name:        "large limited",
+				From:        0,
+				To:          latest,
+				Limit:       100,
+				WantEvents:  events,
+				WantLimited: false,
+			},
+			// The purpose of this test is to check that we can return events in reverse order
+			{
+				Name:         "reverse",
+				From:         positions[len(positions)-3], // 2 events back
+				To:           latest,
+				Limit:        100,
+				ReverseOrder: true,
+				WantEvents:   test.Reversed(events[len(events)-2:]),
+				WantLimited:  false,
+			},
 		}

-		for _, tc := range testCases {
+		for i := range testCases {
+			tc := testCases[i]
 			t.Run(tc.Name, func(st *testing.T) {
+				var filter gomatrixserverlib.RoomEventFilter
+				filter.Limit = tc.Limit
 				gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
 					From: tc.From,
 					To:   tc.To,
-				}, &filter, true, true)
+				}, &filter, !tc.ReverseOrder, true)
 				if err != nil {
 					st.Fatalf("failed to do sync: %s", err)
 				}
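Note: the rewritten table above drives db.RecentEvents directly: From/To bound the stream window, Limit caps how many of the most recent events come back (setting limited when it truncates), and ReverseOrder flips the chronological flag. A toy model of those expectations, matching the "limited" case — this mirrors the assertions only, not Dendrite's actual SQL:

    package main

    import "fmt"

    // recent mimics the test expectations: take events in stream order within
    // [from, to), keep at most limit of the most recent ones, and optionally
    // reverse the result.
    func recent(events []string, from, to, limit int, chronological bool) (out []string, limited bool) {
    	if to > len(events) {
    		to = len(events)
    	}
    	window := events[from:to]
    	if len(window) > limit {
    		window = window[len(window)-limit:] // keep the most recent events
    		limited = true
    	}
    	out = append(out, window...)
    	if !chronological { // ReverseOrder in the test table
    		for i, j := 0, len(out)-1; i < j; i, j = i+1, j-1 {
    			out[i], out[j] = out[j], out[i]
    		}
    	}
    	return out, limited
    }

    func main() {
    	evs := []string{"e1", "e2", "e3", "e4", "e5"}
    	got, limited := recent(evs, 0, 5, 1, true)
    	fmt.Println(got, limited) // [e5] true — matches the "limited" case above
    }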
@@ -148,100 +148,49 @@ func TestRecentEventsPDU(t *testing.T) {
 				if len(gotEvents) != len(tc.WantEvents) {
 					st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents))
 				}
+				for j := range gotEvents {
+					if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) {
+						st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON()))
+					}
+				}
 			})
 		}
 	})
 }

-/*
-func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
-	t.Parallel()
-	db := MustCreateDatabase(t)
-	events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
-	positions := MustWriteEvents(t, db, events)
-	latest, err := db.SyncPosition(ctx)
-	if err != nil {
-		t.Fatalf("failed to get SyncPosition: %s", err)
-	}
-	from := types.StreamingToken{
-		PDUPosition: positions[len(positions)-2],
-	}
-
-	res := types.NewResponse()
-	res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
-	if err != nil {
-		t.Fatalf("failed to IncrementalSync with latest token")
-	}
-	roomRes, ok := res.Rooms.Join[testRoomID]
-	if !ok {
-		t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res)
-	}
-	// returns the last event "Message 10"
-	assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:]))
-
-	prev := roomRes.Timeline.PrevBatch.String()
-	if prev == "" {
-		t.Fatalf("IncrementalSync expected prev_batch token")
-	}
-	prevBatchToken, err := types.NewTopologyTokenFromString(prev)
-	if err != nil {
-		t.Fatalf("failed to NewTopologyTokenFromString : %s", err)
-	}
-	// backpaginate 5 messages starting at the latest position.
-	// head towards the beginning of time
-	to := types.TopologyToken{}
-	paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true)
-	if err != nil {
-		t.Fatalf("GetEventsInRange returned an error: %s", err)
-	}
-	gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
-	assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1]))
-}
-
-// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token.
-func TestGetEventsInRangeWithStreamToken(t *testing.T) {
-	t.Parallel()
-	db := MustCreateDatabase(t)
-	events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
-	MustWriteEvents(t, db, events)
-	latest, err := db.SyncPosition(ctx)
-	if err != nil {
-		t.Fatalf("failed to get SyncPosition: %s", err)
-	}
-	// head towards the beginning of time
-	to := types.StreamingToken{}
-
-	// backpaginate 5 messages starting at the latest position.
-	paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
-	if err != nil {
-		t.Fatalf("GetEventsInRange returned an error: %s", err)
-	}
-	gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
-	assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
-}
-
 // The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token
 func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
-	t.Parallel()
-	db := MustCreateDatabase(t)
-	events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
-	MustWriteEvents(t, db, events)
-	from, err := db.MaxTopologicalPosition(ctx, testRoomID)
-	if err != nil {
-		t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
-	}
-	// head towards the beginning of time
-	to := types.TopologyToken{}
-
-	// backpaginate 5 messages starting at the latest position.
-	paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true)
-	if err != nil {
-		t.Fatalf("GetEventsInRange returned an error: %s", err)
-	}
-	gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
-	assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
+	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+		db, close := MustCreateDatabase(t, dbType)
+		defer close()
+		alice := test.NewUser()
+		r := test.NewRoom(t, alice)
+		for i := 0; i < 10; i++ {
+			r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})
+		}
+		events := r.Events()
+		_ = MustWriteEvents(t, db, events)
+
+		from, err := db.MaxTopologicalPosition(ctx, r.ID)
+		if err != nil {
+			t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
+		}
+		t.Logf("max topo pos = %+v", from)
+		// head towards the beginning of time
+		to := types.TopologyToken{}
+
+		// backpaginate 5 messages starting at the latest position.
+		filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
+		paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
+		if err != nil {
+			t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
+		}
+		gots := db.StreamEventsToEvents(nil, paginatedEvents)
+		test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
+	})
 }

+/*
 // The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth.
 // For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent
 // will appear FIRST when going backwards. This test creates a DAG like:
@@ -651,12 +600,4 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ
 	tok.Decrement()
 	return &tok
 }
-
-func reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
-	out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
-	for i := 0; i < len(in); i++ {
-		out[i] = in[len(in)-i-1]
-	}
-	return out
-}
 */
@@ -59,7 +59,7 @@ type Events interface {
 	SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
 	// SelectEarlyEvents returns the earliest events in the given room.
 	SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
-	SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
+	SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error)
 	UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
 	// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
 	DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
@@ -84,8 +84,6 @@ type Topology interface {
 	SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error)
 	// SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position.
 	SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error)
-	// DeleteTopologyForRoom removes all topological information for a room. This should only be done when removing the room entirely.
-	DeleteTopologyForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
 	// SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room.
 	SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error)
 }
@@ -132,8 +130,6 @@ type BackwardsExtremities interface {
 	SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error)
 	// DeleteBackwardExtremity removes a backwards extremity for a room, if one existed.
 	DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error)
-	// DeleteBackwardExtremitiesFoorRoomID removes all backward extremities for a room. This should only be done when removing the room entirely.
-	DeleteBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
 }

 // SendToDevice tracks send-to-device messages which are sent to individual
@@ -173,7 +169,6 @@ type Receipts interface {

 type Memberships interface {
 	UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
-	SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error)
 	SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
 }
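Note: the SelectEvents signature change above threads a per-request RoomEventFilter and a preserveOrder flag down to the table layer. preserveOrder asks for results in the order the event IDs were requested rather than whatever order the database returns rows in — the property the new output_room_events test below asserts with its "order = 2,0,3,1" check. A standalone model of that reordering step:

    package main

    import "fmt"

    // reorder returns fetched rows in the order the IDs were requested,
    // dropping IDs that were not found — one way a storage layer can
    // implement a preserveOrder flag.
    func reorder(requested []string, fetched map[string]string) []string {
    	out := make([]string, 0, len(requested))
    	for _, id := range requested {
    		if ev, ok := fetched[id]; ok {
    			out = append(out, ev)
    		}
    	}
    	return out
    }

    func main() {
    	ids := []string{"$c", "$a", "$b"}
    	rows := map[string]string{"$a": "A", "$b": "B", "$c": "C"} // arbitrary storage order
    	fmt.Println(reorder(ids, rows))                            // [C A B]
    }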
 syncapi/storage/tables/output_room_events_test.go | 105 (new file)
@@ -0,0 +1,105 @@
+package tables_test
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+	"reflect"
+	"testing"
+
+	"github.com/matrix-org/dendrite/internal/sqlutil"
+	"github.com/matrix-org/dendrite/setup/config"
+	"github.com/matrix-org/dendrite/syncapi/storage/postgres"
+	"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
+	"github.com/matrix-org/dendrite/syncapi/storage/tables"
+	"github.com/matrix-org/dendrite/test"
+	"github.com/matrix-org/gomatrixserverlib"
+)
+
+func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) {
+	t.Helper()
+	connStr, close := test.PrepareDBConnectionString(t, dbType)
+	db, err := sqlutil.Open(&config.DatabaseOptions{
+		ConnectionString: config.DataSource(connStr),
+	})
+	if err != nil {
+		t.Fatalf("failed to open db: %s", err)
+	}
+
+	var tab tables.Events
+	switch dbType {
+	case test.DBTypePostgres:
+		tab, err = postgres.NewPostgresEventsTable(db)
+	case test.DBTypeSQLite:
+		var stream sqlite3.StreamIDStatements
+		if err = stream.Prepare(db); err != nil {
+			t.Fatalf("failed to prepare stream stmts: %s", err)
+		}
+		tab, err = sqlite3.NewSqliteEventsTable(db, &stream)
+	}
+	if err != nil {
+		t.Fatalf("failed to make new table: %s", err)
+	}
+	return tab, db, close
+}
+
+func TestOutputRoomEventsTable(t *testing.T) {
+	ctx := context.Background()
+	alice := test.NewUser()
+	room := test.NewRoom(t, alice)
+	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+		tab, db, close := newOutputRoomEventsTable(t, dbType)
+		defer close()
+		events := room.Events()
+		err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
+			for _, ev := range events {
+				_, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false)
+				if err != nil {
+					return fmt.Errorf("failed to InsertEvent: %s", err)
+				}
+			}
+			// order = 2,0,3,1
+			wantEventIDs := []string{
+				events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(),
+			}
+			gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, nil, true)
+			if err != nil {
+				return fmt.Errorf("failed to SelectEvents: %s", err)
+			}
+			gotEventIDs := make([]string, len(gotEvents))
+			for i := range gotEvents {
+				gotEventIDs[i] = gotEvents[i].EventID()
+			}
+			if !reflect.DeepEqual(gotEventIDs, wantEventIDs) {
+				return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs)
+			}
+
+			// Test that contains_url is correctly populated
+			urlEv := room.CreateEvent(t, alice, "m.text", map[string]interface{}{
+				"body": "test.txt",
+				"url":  "mxc://test.txt",
+			})
+			if _, err = tab.InsertEvent(ctx, txn, urlEv, nil, nil, nil, false); err != nil {
+				return fmt.Errorf("failed to InsertEvent: %s", err)
+			}
+			wantEventID := []string{urlEv.EventID()}
+			t := true
+			gotEvents, err = tab.SelectEvents(ctx, txn, wantEventID, &gomatrixserverlib.RoomEventFilter{Limit: 1, ContainsURL: &t}, true)
+			if err != nil {
+				return fmt.Errorf("failed to SelectEvents: %s", err)
+			}
+			gotEventIDs = make([]string, len(gotEvents))
+			for i := range gotEvents {
+				gotEventIDs[i] = gotEvents[i].EventID()
+			}
+			if !reflect.DeepEqual(gotEventIDs, wantEventID) {
+				return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventID)
+			}
+
+			return nil
+		})
+		if err != nil {
+			t.Fatalf("err: %s", err)
+		}
+	})
+}
 syncapi/storage/tables/topology_test.go | 91 (new file)
@@ -0,0 +1,91 @@
+package tables_test
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+	"testing"
+
+	"github.com/matrix-org/dendrite/internal/sqlutil"
+	"github.com/matrix-org/dendrite/setup/config"
+	"github.com/matrix-org/dendrite/syncapi/storage/postgres"
+	"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
+	"github.com/matrix-org/dendrite/syncapi/storage/tables"
+	"github.com/matrix-org/dendrite/syncapi/types"
+	"github.com/matrix-org/dendrite/test"
+)
+
+func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) {
+	t.Helper()
+	connStr, close := test.PrepareDBConnectionString(t, dbType)
+	db, err := sqlutil.Open(&config.DatabaseOptions{
+		ConnectionString: config.DataSource(connStr),
+	})
+	if err != nil {
+		t.Fatalf("failed to open db: %s", err)
+	}
+
+	var tab tables.Topology
+	switch dbType {
+	case test.DBTypePostgres:
+		tab, err = postgres.NewPostgresTopologyTable(db)
+	case test.DBTypeSQLite:
+		tab, err = sqlite3.NewSqliteTopologyTable(db)
+	}
+	if err != nil {
+		t.Fatalf("failed to make new table: %s", err)
+	}
+	return tab, db, close
+}
+
+func TestTopologyTable(t *testing.T) {
+	ctx := context.Background()
+	alice := test.NewUser()
+	room := test.NewRoom(t, alice)
+	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+		tab, db, close := newTopologyTable(t, dbType)
+		defer close()
+		events := room.Events()
+		err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
+			var highestPos types.StreamPosition
+			for i, ev := range events {
+				topoPos, err := tab.InsertEventInTopology(ctx, txn, ev, types.StreamPosition(i))
+				if err != nil {
+					return fmt.Errorf("failed to InsertEventInTopology: %s", err)
+				}
+				// topo pos = depth, depth starts at 1, hence 1+i
+				if topoPos != types.StreamPosition(1+i) {
+					return fmt.Errorf("got topo pos %d want %d", topoPos, 1+i)
+				}
+				highestPos = topoPos + 1
+			}
+			// check ordering works without limit
+			eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true)
+			if err != nil {
+				return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
+			}
+			test.AssertEventIDsEqual(t, eventIDs, events[:])
+			eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false)
+			if err != nil {
+				return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
+			}
+			test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:]))
+			// check ordering works with limit
+			eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true)
+			if err != nil {
+				return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
+			}
+			test.AssertEventIDsEqual(t, eventIDs, events[:3])
+			eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false)
+			if err != nil {
+				return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
+			}
+			test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:]))
+
+			return nil
+		})
+		if err != nil {
+			t.Fatalf("err: %s", err)
+		}
+	})
+}
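Note: the topology test relies on two properties of the table: an event's topological position equals its depth (starting at 1), and SelectEventIDsInRange walks the (depth, stream position) order either forwards or backwards under a limit. A toy model of the range selection, under the simplifying assumption of a linear room where depth alone orders events:

    package main

    import "fmt"

    type entry struct {
    	eventID string
    	depth   int
    	stream  int
    }

    // selectRange returns event IDs with depth in (from, to], in depth order,
    // reversed when chronological is false — a model of the behaviour the
    // test above asserts, not Dendrite's SQL.
    func selectRange(entries []entry, from, to int, chronological bool) []string {
    	var out []string
    	for _, e := range entries { // entries assumed sorted by (depth, stream)
    		if e.depth > from && e.depth <= to {
    			out = append(out, e.eventID)
    		}
    	}
    	if !chronological {
    		for i, j := 0, len(out)-1; i < j; i, j = i+1, j-1 {
    			out[i], out[j] = out[j], out[i]
    		}
    	}
    	return out
    }

    func main() {
    	entries := []entry{{"$a", 1, 0}, {"$b", 2, 1}, {"$c", 3, 2}}
    	fmt.Println(selectRange(entries, 0, 3, false)) // [$c $b $a]
    }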
@@ -3,9 +3,11 @@ package streams
 import (
 	"context"
 	"database/sql"
+	"fmt"
 	"sync"
 	"time"

+	"github.com/matrix-org/dendrite/internal/caching"
 	"github.com/matrix-org/dendrite/syncapi/types"
 	userapi "github.com/matrix-org/dendrite/userapi/api"
 	"github.com/matrix-org/gomatrixserverlib"
@@ -26,7 +28,8 @@ type PDUStreamProvider struct {

 	tasks   chan func()
 	workers atomic.Int32
-	userAPI userapi.UserInternalAPI
+	// userID+deviceID -> lazy loading cache
+	lazyLoadCache *caching.LazyLoadCache
 }

 func (p *PDUStreamProvider) worker() {
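Note: the new lazyLoadCache field is what makes lazy-loaded room members work across /sync requests — per user+device it remembers which members' membership events a client has already been sent for a room. The real type is caching.LazyLoadCache in internal/caching; the sketch below is an in-memory stand-in with an assumed key layout, showing the two operations the stream provider needs:

    package main

    import (
    	"fmt"
    	"sync"
    )

    // lazyLoadCache models the cache's contract only; the "user/device/room/sender"
    // key layout here is an assumption for illustration, not Dendrite's.
    type lazyLoadCache struct {
    	mu    sync.RWMutex
    	cache map[string]string // key -> membership event ID already sent
    }

    func key(userID, deviceID, roomID, sender string) string {
    	return userID + "/" + deviceID + "/" + roomID + "/" + sender
    }

    func (c *lazyLoadCache) IsCached(userID, deviceID, roomID, sender string) (string, bool) {
    	c.mu.RLock()
    	defer c.mu.RUnlock()
    	id, ok := c.cache[key(userID, deviceID, roomID, sender)]
    	return id, ok
    }

    func (c *lazyLoadCache) Store(userID, deviceID, roomID, sender, eventID string) {
    	c.mu.Lock()
    	defer c.mu.Unlock()
    	c.cache[key(userID, deviceID, roomID, sender)] = eventID
    }

    func main() {
    	c := &lazyLoadCache{cache: map[string]string{}}
    	c.Store("@alice:x", "DEV", "!room:x", "@bob:x", "$membership")
    	fmt.Println(c.IsCached("@alice:x", "DEV", "!room:x", "@bob:x")) // $membership true
    }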
@@ -188,7 +191,7 @@ func (p *PDUStreamProvider) IncrementalSync(
 	newPos = from
 	for _, delta := range stateDeltas {
 		var pos types.StreamPosition
-		if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, req.Response); err != nil {
+		if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, &stateFilter, req.Response); err != nil {
 			req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
 			return to
 		}
@@ -203,12 +206,14 @@ func (p *PDUStreamProvider) IncrementalSync(
 	return newPos
 }

+// nolint:gocyclo
 func (p *PDUStreamProvider) addRoomDeltaToResponse(
 	ctx context.Context,
 	device *userapi.Device,
 	r types.Range,
 	delta types.StateDelta,
 	eventFilter *gomatrixserverlib.RoomEventFilter,
+	stateFilter *gomatrixserverlib.StateFilter,
 	res *types.Response,
 ) (types.StreamPosition, error) {
 	if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave {
@@ -225,13 +230,16 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
 		eventFilter, true, true,
 	)
 	if err != nil {
-		return r.From, err
+		if err == sql.ErrNoRows {
+			return r.To, nil
+		}
+		return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err)
 	}
 	recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
 	delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back
 	prevBatch, err := p.DB.GetBackwardTopologyPos(ctx, recentStreamEvents)
 	if err != nil {
-		return r.From, err
+		return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err)
 	}

 	// If we didn't return any events at all then don't bother doing anything else.
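Note: the error paths above now wrap with fmt.Errorf("...: %w", err) instead of returning err bare, and special-case sql.ErrNoRows. Wrapping with %w adds call-site context while keeping the sentinel detectable by callers:

    package main

    import (
    	"database/sql"
    	"errors"
    	"fmt"
    )

    func main() {
    	err := fmt.Errorf("p.DB.RecentEvents: %w", sql.ErrNoRows)
    	// %w preserves the wrapped sentinel, so callers can still detect "no rows":
    	fmt.Println(errors.Is(err, sql.ErrNoRows)) // true
    	fmt.Println(err)                           // p.DB.RecentEvents: sql: no rows in result set
    }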
@@ -247,7 +255,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
 	// room that were returned.
 	latestPosition := r.To
 	updateLatestPosition := func(mostRecentEventID string) {
-		if _, pos, err := p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil {
+		var pos types.StreamPosition
+		if _, pos, err = p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil {
 			switch {
 			case r.Backwards && pos > latestPosition:
 				fallthrough
@@ -263,6 +272,16 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
 		updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
 	}

+	if stateFilter.LazyLoadMembers {
+		delta.StateEvents, err = p.lazyLoadMembers(
+			ctx, delta.RoomID, true, limited, stateFilter.IncludeRedundantMembers,
+			device, recentEvents, delta.StateEvents,
+		)
+		if err != nil && err != sql.ErrNoRows {
+			return r.From, fmt.Errorf("p.lazyLoadMembers: %w", err)
+		}
+	}
+
 	hasMembershipChange := false
 	for _, recentEvent := range recentStreamEvents {
 		if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil {
@@ -322,12 +341,16 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
 	wantFullState bool,
 	device *userapi.Device,
 ) (jr *types.JoinResponse, err error) {
+	jr = types.NewJoinResponse()
 	// TODO: When filters are added, we may need to call this multiple times to get enough events.
 	// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
 	recentStreamEvents, limited, err := p.DB.RecentEvents(
 		ctx, roomID, r, eventFilter, true, true,
 	)
 	if err != nil {
+		if err == sql.ErrNoRows {
+			return jr, nil
+		}
 		return
 	}

@@ -402,7 +425,20 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
 	// "Can sync a room with a message with a transaction id" - which does a complete sync to check.
 	recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
 	stateEvents = removeDuplicates(stateEvents, recentEvents)
-	jr = types.NewJoinResponse()
+	if stateFilter.LazyLoadMembers {
+		if err != nil {
+			return nil, err
+		}
+		stateEvents, err = p.lazyLoadMembers(ctx, roomID,
+			false, limited, stateFilter.IncludeRedundantMembers,
+			device, recentEvents, stateEvents,
+		)
+		if err != nil && err != sql.ErrNoRows {
+			return nil, err
+		}
+	}

 	jr.Summary.JoinedMemberCount = &joinedCount
 	jr.Summary.InvitedMemberCount = &invitedCount
 	jr.Timeline.PrevBatch = prevBatch
@@ -412,6 +448,69 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
 	return jr, nil
 }

+func (p *PDUStreamProvider) lazyLoadMembers(
+	ctx context.Context, roomID string,
+	incremental, limited, includeRedundant bool,
+	device *userapi.Device,
+	timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
+) ([]*gomatrixserverlib.HeaderedEvent, error) {
+	if len(timelineEvents) == 0 {
+		return stateEvents, nil
+	}
+	// Work out which memberships to include
+	timelineUsers := make(map[string]struct{})
+	if !incremental {
+		timelineUsers[device.UserID] = struct{}{}
+	}
+	// Add all users the client doesn't know about yet to a list
+	for _, event := range timelineEvents {
+		// Membership is not yet cached, add it to the list
+		if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.Sender()); !ok {
+			timelineUsers[event.Sender()] = struct{}{}
+		}
+	}
+	// Preallocate with the same amount, even if it will end up with fewer values
+	newStateEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(stateEvents))
+	// Remove existing membership events we don't care about, e.g. users not in the timeline.events
+	for _, event := range stateEvents {
+		if event.Type() == gomatrixserverlib.MRoomMember && event.StateKey() != nil {
+			// If this is a gapped incremental sync, we still want this membership
+			isGappedIncremental := limited && incremental
+			// We want this users membership event, keep it in the list
+			_, ok := timelineUsers[event.Sender()]
+			wantMembership := ok || isGappedIncremental
+			if wantMembership {
+				newStateEvents = append(newStateEvents, event)
+				if !includeRedundant {
+					p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, event.Sender(), event.EventID())
+				}
+				delete(timelineUsers, event.Sender())
+			}
+		} else {
+			newStateEvents = append(newStateEvents, event)
+		}
+	}
+	wantUsers := make([]string, 0, len(timelineUsers))
+	for userID := range timelineUsers {
+		wantUsers = append(wantUsers, userID)
+	}
+	// Query missing membership events
+	memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &gomatrixserverlib.StateFilter{
+		Limit:   100,
+		Senders: &wantUsers,
+		Types:   &[]string{gomatrixserverlib.MRoomMember},
+	})
+	if err != nil {
+		return stateEvents, err
+	}
+	// cache the membership events
+	for _, membership := range memberships {
+		p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, membership.Sender(), membership.EventID())
+	}
+	stateEvents = append(newStateEvents, memberships...)
+	return stateEvents, nil
+}
+
 // addIgnoredUsersToFilter adds ignored users to the eventfilter and
 // the syncreq itself for further use in streams.
 func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error {
@@ -423,8 +522,12 @@ func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *ty
 		return err
 	}
 	req.IgnoredUsers = *ignores
+	userList := make([]string, 0, len(ignores.List))
 	for userID := range ignores.List {
-		eventFilter.NotSenders = append(eventFilter.NotSenders, userID)
+		userList = append(userList, userID)
+	}
+	if len(userList) > 0 {
+		eventFilter.NotSenders = &userList
 	}
 	return nil
 }
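Note: NotSenders is now a pointer to a slice, so the code collects ignored users first and only sets the field when the list is non-empty — a nil pointer means "no filter configured", while a non-nil pointer (even to an empty slice) means "apply this filter". A small illustration of that distinction:

    package main

    import "fmt"

    // allowed models pointer-to-slice filter semantics: nil means "field absent,
    // don't filter"; non-nil means "exclude these senders".
    func allowed(sender string, notSenders *[]string) bool {
    	if notSenders == nil {
    		return true // no filter configured
    	}
    	for _, s := range *notSenders {
    		if s == sender {
    			return false
    		}
    	}
    	return true
    }

    func main() {
    	ignored := []string{"@spam:example.org"}
    	fmt.Println(allowed("@spam:example.org", &ignored)) // false
    	fmt.Println(allowed("@spam:example.org", nil))      // true
    }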
@@ -27,12 +27,12 @@ type Streams struct {
 func NewSyncStreamProviders(
 	d storage.Database, userAPI userapi.UserInternalAPI,
 	rsAPI rsapi.RoomserverInternalAPI, keyAPI keyapi.KeyInternalAPI,
-	eduCache *caching.EDUCache, notifier *notifier.Notifier,
+	eduCache *caching.EDUCache, lazyLoadCache *caching.LazyLoadCache, notifier *notifier.Notifier,
 ) *Streams {
 	streams := &Streams{
 		PDUStreamProvider: &PDUStreamProvider{
 			StreamProvider: StreamProvider{DB: d},
-			userAPI:        userAPI,
+			lazyLoadCache:  lazyLoadCache,
 		},
 		TypingStreamProvider: &TypingStreamProvider{
 			StreamProvider: StreamProvider{DB: d},
@@ -15,6 +15,7 @@
 package sync

 import (
+	"database/sql"
 	"encoding/json"
 	"fmt"
 	"net/http"
@@ -60,10 +61,10 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
 			util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
 			return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
 		}
-		if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil {
+		if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil && err != sql.ErrNoRows {
 			util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed")
 			return nil, fmt.Errorf("syncDB.GetFilter: %w", err)
-		} else {
+		} else if f != nil {
 			filter = *f
 		}
 	}
@@ -57,8 +57,12 @@ func AddPublicRoutes(
 	}

 	eduCache := caching.NewTypingCache()
+	lazyLoadCache, err := caching.NewLazyLoadCache()
+	if err != nil {
+		logrus.WithError(err).Panicf("failed to create lazy loading cache")
+	}
 	notifier := notifier.NewNotifier()
-	streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, notifier)
+	streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, lazyLoadCache, notifier)
 	notifier.SetCurrentPosition(streams.Latest(context.Background()))
 	if err = notifier.Load(context.Background(), syncDB); err != nil {
 		logrus.WithError(err).Panicf("failed to load notifier ")
@@ -312,10 +312,10 @@ Inbound federation can return events
 Inbound federation can return missing events for world_readable visibility
 Inbound federation can return missing events for invite visibility
 Inbound federation can get public room list
-POST /rooms/:room_id/redact/:event_id as power user redacts message
-POST /rooms/:room_id/redact/:event_id as original message sender redacts message
-POST /rooms/:room_id/redact/:event_id as random user does not redact message
-POST /redact disallows redaction of event in different room
+PUT /rooms/:room_id/redact/:event_id/:txn_id as power user redacts message
+PUT /rooms/:room_id/redact/:event_id/:txn_id as original message sender redacts message
+PUT /rooms/:room_id/redact/:event_id/:txn_id as random user does not redact message
+PUT /redact disallows redaction of event in different room
 An event which redacts itself should be ignored
 A pair of events which redact each other should be ignored
 Redaction of a redaction redacts the redaction reason
@@ -696,4 +696,17 @@ Room state after a rejected message event is the same as before
 Room state after a rejected state event is the same as before
 Ignore user in existing room
 Ignore invite in full sync
 Ignore invite in incremental sync
+A filtered timeline reaches its limit
+A change to displayname should not result in a full state sync
+Can fetch images in room
+The only membership state included in an initial sync is for all the senders in the timeline
+The only membership state included in an incremental sync is for senders in the timeline
+Old members are included in gappy incr LL sync if they start speaking
+We do send redundant membership state across incremental syncs if asked
+Rejecting invite over federation doesn't break incremental /sync
+Gapped incremental syncs include all state changes
+Old leaves are present in gapped incremental syncs
+Leaves are present in non-gapped incremental syncs
+Members from the gap are included in gappy incr LL sync
+Presence can be set from sync
 test/db.go | 59
@@ -15,12 +15,16 @@
 package test

 import (
+	"crypto/sha256"
 	"database/sql"
+	"encoding/hex"
 	"fmt"
 	"os"
 	"os/exec"
 	"os/user"
 	"testing"

+	"github.com/lib/pq"
 )

 type DBType int
@@ -30,7 +34,7 @@ var DBTypePostgres DBType = 2

 var Quiet = false

-func createLocalDB(dbName string) string {
+func createLocalDB(dbName string) {
 	if !Quiet {
 		fmt.Println("Note: tests require a postgres install accessible to the current user")
 	}
@@ -43,7 +47,29 @@ func createLocalDB(dbName string) string {
 	if err != nil && !Quiet {
 		fmt.Println("createLocalDB returned error:", err)
 	}
-	return dbName
+}
+
+func createRemoteDB(t *testing.T, dbName, user, connStr string) {
+	db, err := sql.Open("postgres", connStr+" dbname=postgres")
+	if err != nil {
+		t.Fatalf("failed to open postgres conn with connstr=%s : %s", connStr, err)
+	}
+	_, err = db.Exec(fmt.Sprintf(`CREATE DATABASE %s;`, dbName))
+	if err != nil {
+		pqErr, ok := err.(*pq.Error)
+		if !ok {
+			t.Fatalf("failed to CREATE DATABASE: %s", err)
+		}
+		// we ignore duplicate database error as we expect this
+		if pqErr.Code != "42P04" {
+			t.Fatalf("failed to CREATE DATABASE with code=%s msg=%s", pqErr.Code, pqErr.Message)
+		}
+	}
+	_, err = db.Exec(fmt.Sprintf(`GRANT ALL PRIVILEGES ON DATABASE %s TO %s`, dbName, user))
+	if err != nil {
+		t.Fatalf("failed to GRANT: %s", err)
+	}
+	_ = db.Close()
 }

 func currentUser() string {
@@ -64,6 +90,7 @@ func currentUser() string {
 // TODO: namespace for concurrent package tests
 func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) {
 	if dbType == DBTypeSQLite {
+		// this will be made in the current working directory which namespaces concurrent package runs correctly
 		dbname := "dendrite_test.db"
 		return fmt.Sprintf("file:%s", dbname), func() {
 			err := os.Remove(dbname)
@@ -79,13 +106,9 @@ func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, clo
 	if user == "" {
 		user = currentUser()
 	}
-	dbName := os.Getenv("POSTGRES_DB")
-	if dbName == "" {
-		dbName = createLocalDB("dendrite_test")
-	}
 	connStr = fmt.Sprintf(
-		"user=%s dbname=%s sslmode=disable",
-		user, dbName,
+		"user=%s sslmode=disable",
+		user,
 	)
 	// optional vars, used in CI
 	password := os.Getenv("POSTGRES_PASSWORD")
@@ -97,6 +120,25 @@ func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, clo
 		connStr += fmt.Sprintf(" host=%s", host)
 	}

+	// superuser database
+	postgresDB := os.Getenv("POSTGRES_DB")
+	// we cannot use 'dendrite_test' here else 2x concurrently running packages will try to use the same db.
+	// instead, hash the current working directory, snaffle the first 16 bytes and append that to dendrite_test
+	// and use that as the unique db name. We do this because packages are per-directory hence by hashing the
+	// working (test) directory we ensure we get a consistent hash and don't hash against concurrent packages.
+	wd, err := os.Getwd()
+	if err != nil {
+		t.Fatalf("cannot get working directory: %s", err)
+	}
+	hash := sha256.Sum256([]byte(wd))
+	dbName := fmt.Sprintf("dendrite_test_%s", hex.EncodeToString(hash[:16]))
+	if postgresDB == "" { // local server, use createdb
+		createLocalDB(dbName)
+	} else { // remote server, shell into the postgres user and CREATE DATABASE
+		createRemoteDB(t, dbName, user, connStr)
+	}
+	connStr += fmt.Sprintf(" dbname=%s", dbName)
+
 	return connStr, func() {
 		// Drop all tables on the database to get a fresh instance
 		db, err := sql.Open("postgres", connStr)
@@ -121,6 +163,7 @@ func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) {
 	for dbName, dbType := range dbs {
 		dbt := dbType
 		t.Run(dbName, func(tt *testing.T) {
+			tt.Parallel()
 			testFn(tt, dbt)
 		})
 	}
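Note: the harness above namespaces Postgres databases per test package by hashing the package's working directory — Go runs each package's tests from that package's own directory, so the hash is stable per package and distinct across concurrently running packages. The same derivation, runnable on its own:

    package main

    import (
    	"crypto/sha256"
    	"encoding/hex"
    	"fmt"
    	"os"
    )

    func main() {
    	// Mirrors the naming scheme in test/db.go: hash the working directory
    	// and keep the first 16 bytes, giving each test package a stable,
    	// collision-free database name.
    	wd, err := os.Getwd()
    	if err != nil {
    		panic(err)
    	}
    	hash := sha256.Sum256([]byte(wd))
    	fmt.Printf("dendrite_test_%s\n", hex.EncodeToString(hash[:16]))
    }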
@@ -15,7 +15,9 @@
 package test

 import (
+	"bytes"
 	"crypto/ed25519"
+	"testing"
 	"time"

 	"github.com/matrix-org/gomatrixserverlib"
@@ -49,3 +51,40 @@ func WithUnsigned(unsigned interface{}) eventModifier {
 		e.unsigned = unsigned
 	}
 }
+
+// Reverse a list of events
+func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
+	out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
+	for i := 0; i < len(in); i++ {
+		out[i] = in[len(in)-i-1]
+	}
+	return out
+}
+
+func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*gomatrixserverlib.HeaderedEvent) {
+	t.Helper()
+	if len(gotEventIDs) != len(wants) {
+		t.Fatalf("length mismatch: got %d events, want %d", len(gotEventIDs), len(wants))
+	}
+	for i := range wants {
+		w := wants[i].EventID()
+		g := gotEventIDs[i]
+		if w != g {
+			t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
+		}
+	}
+}
+
+func AssertEventsEqual(t *testing.T, gots, wants []*gomatrixserverlib.HeaderedEvent) {
+	t.Helper()
+	if len(gots) != len(wants) {
+		t.Fatalf("length mismatch: got %d events, want %d", len(gots), len(wants))
+	}
+	for i := range wants {
+		w := wants[i].JSON()
+		g := gots[i].JSON()
+		if !bytes.Equal(w, g) {
+			t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
+		}
+	}
+}
@ -492,16 +492,16 @@ type PerformPusherDeletionRequest struct {
|
||||||
|
|
||||||
// Pusher represents a push notification subscriber
|
// Pusher represents a push notification subscriber
|
||||||
type Pusher struct {
|
type Pusher struct {
|
||||||
SessionID int64 `json:"session_id,omitempty"`
|
SessionID int64 `json:"session_id,omitempty"`
|
||||||
PushKey string `json:"pushkey"`
|
PushKey string `json:"pushkey"`
|
||||||
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
|
PushKeyTS int64 `json:"pushkey_ts,omitempty"`
|
||||||
Kind PusherKind `json:"kind"`
|
Kind PusherKind `json:"kind"`
|
||||||
AppID string `json:"app_id"`
|
AppID string `json:"app_id"`
|
||||||
AppDisplayName string `json:"app_display_name"`
|
AppDisplayName string `json:"app_display_name"`
|
||||||
DeviceDisplayName string `json:"device_display_name"`
|
DeviceDisplayName string `json:"device_display_name"`
|
||||||
ProfileTag string `json:"profile_tag"`
|
ProfileTag string `json:"profile_tag"`
|
||||||
Language string `json:"lang"`
|
Language string `json:"lang"`
|
||||||
Data map[string]interface{} `json:"data"`
|
Data map[string]interface{} `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PusherKind string
|
type PusherKind string
|
||||||
|
|
|
||||||
|
|
@@ -653,7 +653,7 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
 		return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
 	}
 	if req.Pusher.PushKeyTS == 0 {
-		req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now())
+		req.Pusher.PushKeyTS = int64(time.Now().Unix())
 	}
 	return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
 }
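Note that time.Now().Unix() already returns an int64, so the explicit conversion is redundant (though harmless). The zero-value defaulting used here is a common Go pattern; a hedged, standalone sketch:

package main

import (
	"fmt"
	"time"
)

// defaultPushKeyTS mirrors the zero-value check above: fill in the current
// Unix time (seconds) only when the caller supplied no timestamp.
func defaultPushKeyTS(ts int64) int64 {
	if ts == 0 {
		return time.Now().Unix()
	}
	return ts
}

func main() {
	fmt.Println(defaultPushKeyTS(0))          // current time
	fmt.Println(defaultPushKeyTS(1257894000)) // caller value preserved
}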
@@ -23,7 +23,6 @@ import (
 	"github.com/matrix-org/dendrite/internal/sqlutil"
 	"github.com/matrix-org/dendrite/userapi/api"
 	"github.com/matrix-org/dendrite/userapi/storage/tables"
-	"github.com/matrix-org/gomatrixserverlib"
 	"github.com/sirupsen/logrus"
 )
 
@@ -95,7 +94,7 @@ type pushersStatements struct {
 // Returns nil error success.
 func (s *pushersStatements) InsertPusher(
 	ctx context.Context, txn *sql.Tx, session_id int64,
-	pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
+	pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
 ) error {
 	_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
 	logrus.Debugf("Created pusher %d", session_id)
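This variant routes the prepared statement through sqlutil.TxStmt so the insert joins any in-progress transaction; the second copy of this hunk below (evidently the SQLite counterpart) executes the statement directly. A small illustration of what such a wrapper typically does, as a local sketch rather than dendrite's actual sqlutil code:

package sqlhelpers

import "database/sql"

// txStmt mirrors the usual TxStmt pattern: if a transaction is supplied,
// rebind the prepared statement to it so the write participates in the
// transaction; otherwise fall back to the bare statement.
func txStmt(txn *sql.Tx, stmt *sql.Stmt) *sql.Stmt {
	if txn != nil {
		return txn.Stmt(stmt)
	}
	return stmt
}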
@@ -23,7 +23,6 @@ import (
 	"github.com/matrix-org/dendrite/internal/sqlutil"
 	"github.com/matrix-org/dendrite/userapi/api"
 	"github.com/matrix-org/dendrite/userapi/storage/tables"
-	"github.com/matrix-org/gomatrixserverlib"
 	"github.com/sirupsen/logrus"
 )
 
@@ -95,7 +94,7 @@ type pushersStatements struct {
 // Returns nil error success.
 func (s *pushersStatements) InsertPusher(
 	ctx context.Context, txn *sql.Tx, session_id int64,
-	pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
+	pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
 ) error {
 	_, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
 	logrus.Debugf("Created pusher %d", session_id)
@@ -21,7 +21,6 @@ import (
 
 	"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
 	"github.com/matrix-org/dendrite/userapi/api"
-	"github.com/matrix-org/gomatrixserverlib"
 )
 
 type AccountDataTable interface {
@@ -96,7 +95,7 @@ type ThreePIDTable interface {
 }
 
 type PusherTable interface {
-	InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
+	InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
 	SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
 	DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
 	DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
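Both storage backends above must now satisfy this int64 signature. One way to see the call shape is a stub implementation against a trimmed copy of the interface; the types below are illustrative only, not dendrite's test code:

package main

import (
	"context"
	"database/sql"
	"fmt"
)

// pusherInserter is a trimmed, illustrative slice of the PusherTable
// interface above, reduced to the method whose signature changed.
type pusherInserter interface {
	InsertPusher(ctx context.Context, txn *sql.Tx, sessionID int64,
		pushkey string, pushkeyTS int64, kind, appid, localpart string) error
}

// fakeTable is a stand-in implementation showing the call shape.
type fakeTable struct{}

func (fakeTable) InsertPusher(ctx context.Context, txn *sql.Tx, sessionID int64,
	pushkey string, pushkeyTS int64, kind, appid, localpart string) error {
	fmt.Printf("insert pusher %q for %q at %d\n", pushkey, localpart, pushkeyTS)
	return nil
}

func main() {
	var tbl pusherInserter = fakeTable{}
	_ = tbl.InsertPusher(context.Background(), nil, 1, "a_push_key", 1257894000, "http", "example.app", "alice")
}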