mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-01 03:03:10 -06:00
Merge branch 'main' into neilalexander/removelibp2p
This commit is contained in:
commit
786b89dcd5
|
|
@ -52,6 +52,7 @@ import (
|
|||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
|
||||
pineconeConnections "github.com/matrix-org/pinecone/connections"
|
||||
pineconeMulticast "github.com/matrix-org/pinecone/multicast"
|
||||
pineconeRouter "github.com/matrix-org/pinecone/router"
|
||||
pineconeSessions "github.com/matrix-org/pinecone/sessions"
|
||||
|
|
@ -71,11 +72,9 @@ type DendriteMonolith struct {
|
|||
PineconeRouter *pineconeRouter.Router
|
||||
PineconeMulticast *pineconeMulticast.Multicast
|
||||
PineconeQUIC *pineconeSessions.Sessions
|
||||
PineconeManager *pineconeConnections.ConnectionManager
|
||||
StorageDirectory string
|
||||
CacheDirectory string
|
||||
staticPeerURI string
|
||||
staticPeerMutex sync.RWMutex
|
||||
staticPeerAttempt chan struct{}
|
||||
listener net.Listener
|
||||
httpServer *http.Server
|
||||
processContext *process.ProcessContext
|
||||
|
|
@ -104,15 +103,8 @@ func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) {
|
|||
}
|
||||
|
||||
func (m *DendriteMonolith) SetStaticPeer(uri string) {
|
||||
m.staticPeerMutex.Lock()
|
||||
m.staticPeerURI = strings.TrimSpace(uri)
|
||||
m.staticPeerMutex.Unlock()
|
||||
m.DisconnectType(int(pineconeRouter.PeerTypeRemote))
|
||||
if uri != "" {
|
||||
go func() {
|
||||
m.staticPeerAttempt <- struct{}{}
|
||||
}()
|
||||
}
|
||||
m.PineconeManager.RemovePeers()
|
||||
m.PineconeManager.AddPeer(strings.TrimSpace(uri))
|
||||
}
|
||||
|
||||
func (m *DendriteMonolith) DisconnectType(peertype int) {
|
||||
|
|
@ -210,43 +202,6 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e
|
|||
return loginRes.Device.AccessToken, nil
|
||||
}
|
||||
|
||||
func (m *DendriteMonolith) staticPeerConnect() {
|
||||
connected := map[string]bool{} // URI -> connected?
|
||||
attempt := func() {
|
||||
m.staticPeerMutex.RLock()
|
||||
uri := m.staticPeerURI
|
||||
m.staticPeerMutex.RUnlock()
|
||||
if uri == "" {
|
||||
return
|
||||
}
|
||||
for k := range connected {
|
||||
delete(connected, k)
|
||||
}
|
||||
for _, uri := range strings.Split(uri, ",") {
|
||||
connected[strings.TrimSpace(uri)] = false
|
||||
}
|
||||
for _, info := range m.PineconeRouter.Peers() {
|
||||
connected[info.URI] = true
|
||||
}
|
||||
for k, online := range connected {
|
||||
if !online {
|
||||
if err := conn.ConnectToPeer(m.PineconeRouter, k); err != nil {
|
||||
logrus.WithError(err).Error("Failed to connect to static peer")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-m.processContext.Context().Done():
|
||||
case <-m.staticPeerAttempt:
|
||||
attempt()
|
||||
case <-time.After(time.Second * 5):
|
||||
attempt()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// nolint:gocyclo
|
||||
func (m *DendriteMonolith) Start() {
|
||||
var err error
|
||||
|
|
@ -284,6 +239,7 @@ func (m *DendriteMonolith) Start() {
|
|||
m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
|
||||
m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"})
|
||||
m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter)
|
||||
m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter)
|
||||
|
||||
prefix := hex.EncodeToString(pk)
|
||||
cfg := &config.Dendrite{}
|
||||
|
|
@ -392,9 +348,6 @@ func (m *DendriteMonolith) Start() {
|
|||
|
||||
m.processContext = base.ProcessContext
|
||||
|
||||
m.staticPeerAttempt = make(chan struct{}, 1)
|
||||
go m.staticPeerConnect()
|
||||
|
||||
go func() {
|
||||
m.logger.Info("Listening on ", cfg.Global.ServerName)
|
||||
m.logger.Fatal(m.httpServer.Serve(m.PineconeQUIC.Protocol("matrix")))
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client
|
|||
|
||||
if turnConfig.SharedSecret != "" {
|
||||
expiry := time.Now().Add(duration).Unix()
|
||||
resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID)
|
||||
mac := hmac.New(sha1.New, []byte(turnConfig.SharedSecret))
|
||||
_, err := mac.Write([]byte(resp.Username))
|
||||
|
||||
|
|
@ -60,7 +61,6 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client
|
|||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID)
|
||||
resp.Password = base64.StdEncoding.EncodeToString(mac.Sum(nil))
|
||||
} else if turnConfig.Username != "" && turnConfig.Password != "" {
|
||||
resp.Username = turnConfig.Username
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
|
@ -47,6 +46,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
pineconeConnections "github.com/matrix-org/pinecone/connections"
|
||||
pineconeMulticast "github.com/matrix-org/pinecone/multicast"
|
||||
pineconeRouter "github.com/matrix-org/pinecone/router"
|
||||
pineconeSessions "github.com/matrix-org/pinecone/sessions"
|
||||
|
|
@ -90,6 +90,13 @@ func main() {
|
|||
}
|
||||
|
||||
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
|
||||
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
|
||||
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
|
||||
pManager := pineconeConnections.NewConnectionManager(pRouter)
|
||||
pMulticast.Start()
|
||||
if instancePeer != nil && *instancePeer != "" {
|
||||
pManager.AddPeer(*instancePeer)
|
||||
}
|
||||
|
||||
go func() {
|
||||
listener, err := net.Listen("tcp", *instanceListen)
|
||||
|
|
@ -119,36 +126,6 @@ func main() {
|
|||
}
|
||||
}()
|
||||
|
||||
pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
|
||||
pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
|
||||
pMulticast.Start()
|
||||
|
||||
connectToStaticPeer := func() {
|
||||
connected := map[string]bool{} // URI -> connected?
|
||||
for _, uri := range strings.Split(*instancePeer, ",") {
|
||||
connected[strings.TrimSpace(uri)] = false
|
||||
}
|
||||
attempt := func() {
|
||||
for k := range connected {
|
||||
connected[k] = false
|
||||
}
|
||||
for _, info := range pRouter.Peers() {
|
||||
connected[info.URI] = true
|
||||
}
|
||||
for k, online := range connected {
|
||||
if !online {
|
||||
if err := conn.ConnectToPeer(pRouter, k); err != nil {
|
||||
logrus.WithError(err).Error("Failed to connect to static peer")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for {
|
||||
attempt()
|
||||
time.Sleep(time.Second * 5)
|
||||
}
|
||||
}
|
||||
|
||||
cfg := &config.Dendrite{}
|
||||
cfg.Defaults(true)
|
||||
cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))
|
||||
|
|
@ -268,7 +245,6 @@ func main() {
|
|||
Handler: pMux,
|
||||
}
|
||||
|
||||
go connectToStaticPeer()
|
||||
go func() {
|
||||
pubkey := pRouter.PublicKey()
|
||||
logrus.Info("Listening on ", hex.EncodeToString(pubkey[:]))
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ import (
|
|||
"encoding/hex"
|
||||
"fmt"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/appservice"
|
||||
|
|
@ -44,6 +43,7 @@ import (
|
|||
|
||||
_ "github.com/matrix-org/go-sqlite3-js"
|
||||
|
||||
pineconeConnections "github.com/matrix-org/pinecone/connections"
|
||||
pineconeRouter "github.com/matrix-org/pinecone/router"
|
||||
pineconeSessions "github.com/matrix-org/pinecone/sessions"
|
||||
)
|
||||
|
|
@ -154,6 +154,8 @@ func startup() {
|
|||
|
||||
pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
|
||||
pSessions := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
|
||||
pManager := pineconeConnections.NewConnectionManager(pRouter)
|
||||
pManager.AddPeer("wss://pinecone.matrix.org/public")
|
||||
|
||||
cfg := &config.Dendrite{}
|
||||
cfg.Defaults(true)
|
||||
|
|
@ -237,20 +239,4 @@ func startup() {
|
|||
}
|
||||
s.ListenAndServe("fetch")
|
||||
}()
|
||||
|
||||
// Connect to the static peer
|
||||
go func() {
|
||||
for {
|
||||
if pRouter.PeerCount(pineconeRouter.PeerTypeRemote) == 0 {
|
||||
if err := conn.ConnectToPeer(pRouter, publicPeer); err != nil {
|
||||
logrus.WithError(err).Error("Failed to connect to static peer")
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-base.ProcessContext.Context().Done():
|
||||
return
|
||||
case <-time.After(time.Second * 5):
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
|
|||
13
go.mod
13
go.mod
|
|
@ -1,8 +1,8 @@
|
|||
module github.com/matrix-org/dendrite
|
||||
|
||||
replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e
|
||||
replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9
|
||||
|
||||
replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c
|
||||
replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e
|
||||
|
||||
require (
|
||||
github.com/Arceliar/ironwood v0.0.0-20211125050254-8951369625d0
|
||||
|
|
@ -27,15 +27,15 @@ require (
|
|||
github.com/juju/testing v0.0.0-20220203020004-a0ff61f03494 // indirect
|
||||
github.com/lib/pq v1.10.5
|
||||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5
|
||||
github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f
|
||||
github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||
github.com/mattn/go-sqlite3 v1.14.10
|
||||
github.com/miekg/dns v1.1.31 // indirect
|
||||
github.com/nats-io/nats-server/v2 v2.7.4-0.20220309205833-773636c1c5bb
|
||||
github.com/nats-io/nats.go v1.13.1-0.20220308171302-2f2f6968e98d
|
||||
github.com/nats-io/nats.go v1.14.0
|
||||
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||
github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31
|
||||
|
|
@ -56,6 +56,7 @@ require (
|
|||
golang.org/x/image v0.0.0-20220321031419-a8550c1d254a
|
||||
golang.org/x/mobile v0.0.0-20220407111146-e579adbbc4a2
|
||||
golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3
|
||||
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
|
||||
gopkg.in/h2non/bimg.v1 v1.1.9
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
|
|
|
|||
27
go.sum
27
go.sum
|
|
@ -789,15 +789,15 @@ github.com/masterzen/winrm v0.0.0-20161014151040-7a535cd943fc/go.mod h1:CfZSN7zw
|
|||
github.com/masterzen/xmlpath v0.0.0-20140218185901-13f4951698ad/go.mod h1:A0zPC53iKKKcXYxr4ROjpQRQ5FgJXtelNdSmHHuq/tY=
|
||||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e h1:DP5RC0Z3XdyBEW5dKt8YPeN6vZbm6OzVaGVp7f1BQRM=
|
||||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg=
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d h1:mGhPVaTht5NViFN/UpdrIlRApmH2FWcVaKUH5MdBKiY=
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw2QV3YD/fRrzEDPNGgTlJlvXY0EHHnT87wF3OA=
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5 h1:Fkennny7+Z/5pygrhjFMZbz1j++P2hhhLoT7NO3p8DQ=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48=
|
||||
github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d h1:1+T4eOPRsf6cr0lMPW4oO2k8TTHm4mqIh65kpEID5Rk=
|
||||
github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f h1:MZrl4TgTnlaOn2Cu9gJCoJ3oyW5mT4/3QIZGgZXzKl4=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48=
|
||||
github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 h1:W0sjjC6yjskHX4mb0nk3p0fXAlbU5bAFUFeEtlrPASE=
|
||||
github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc=
|
||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||
|
|
@ -878,8 +878,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m
|
|||
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
|
||||
github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296 h1:vU9tpM3apjYlLLeY23zRWJ9Zktr5jp+mloR942LEOpY=
|
||||
github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
|
||||
github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a h1:lem6QCvxR0Y28gth9P+wV2K/zYUUAkJ+55U8cpS0p5I=
|
||||
github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
|
||||
github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8=
|
||||
github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4=
|
||||
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
|
||||
|
|
@ -888,10 +888,10 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY
|
|||
github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM=
|
||||
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
|
||||
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
|
||||
github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e h1:5tEHLzvDeS6IeqO2o9FFhsE3V2erYj8FlMt2J91wzsk=
|
||||
github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e/go.mod h1:1vZ2Nijh8tcyNe8BDVyTviCd9NYzRbubQYiEHsvOQWc=
|
||||
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q=
|
||||
github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
|
||||
github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9 h1:VGU5HYAwy8LRbSkrT+kCHvujVmwK8Aa/vc1O+eReTbM=
|
||||
github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9/go.mod h1:5vic7C58BFEVltiZhs7Kq81q2WcEPhJPsmNv1FOrdv0=
|
||||
github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e h1:kNIzIzj2OvnlreA+sTJ12nWJzTP3OSLNKDL/Iq9mF6Y=
|
||||
github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
|
||||
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 h1:lrVQzBtkeQEGGYUHwSX1XPe1E5GL6U3KYCNe2G4bncQ=
|
||||
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
|
|
@ -1541,8 +1541,9 @@ golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64 h1:D1v9ucDTYBtbz5vNuBbAhIMAGhQhJ6Ym5ah3maMVNX4=
|
||||
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 h1:QyVthZKMsyaQwBTJE04jdNN0Pp5Fn9Qga0mrgxyERQM=
|
||||
golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
|
|
|
|||
86
internal/caching/cache_lazy_load_members.go
Normal file
86
internal/caching/cache_lazy_load_members.go
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
package caching
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
const (
|
||||
LazyLoadCacheName = "lazy_load_members"
|
||||
LazyLoadCacheMaxEntries = 128
|
||||
LazyLoadCacheMaxUserEntries = 128
|
||||
LazyLoadCacheMutable = true
|
||||
LazyLoadCacheMaxAge = time.Minute * 30
|
||||
)
|
||||
|
||||
type LazyLoadCache struct {
|
||||
// InMemoryLRUCachePartition containing other InMemoryLRUCachePartitions
|
||||
// with the actual cached members
|
||||
userCaches *InMemoryLRUCachePartition
|
||||
}
|
||||
|
||||
// NewLazyLoadCache creates a new LazyLoadCache.
|
||||
func NewLazyLoadCache() (*LazyLoadCache, error) {
|
||||
cache, err := NewInMemoryLRUCachePartition(
|
||||
LazyLoadCacheName,
|
||||
LazyLoadCacheMutable,
|
||||
LazyLoadCacheMaxEntries,
|
||||
LazyLoadCacheMaxAge,
|
||||
true,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go cacheCleaner(cache)
|
||||
return &LazyLoadCache{
|
||||
userCaches: cache,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *LazyLoadCache) lazyLoadCacheForUser(device *userapi.Device) (*InMemoryLRUCachePartition, error) {
|
||||
cacheName := fmt.Sprintf("%s/%s", device.UserID, device.ID)
|
||||
userCache, ok := c.userCaches.Get(cacheName)
|
||||
if ok && userCache != nil {
|
||||
if cache, ok := userCache.(*InMemoryLRUCachePartition); ok {
|
||||
return cache, nil
|
||||
}
|
||||
}
|
||||
cache, err := NewInMemoryLRUCachePartition(
|
||||
LazyLoadCacheName,
|
||||
LazyLoadCacheMutable,
|
||||
LazyLoadCacheMaxUserEntries,
|
||||
LazyLoadCacheMaxAge,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.userCaches.Set(cacheName, cache)
|
||||
go cacheCleaner(cache)
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
func (c *LazyLoadCache) StoreLazyLoadedUser(device *userapi.Device, roomID, userID, eventID string) {
|
||||
cache, err := c.lazyLoadCacheForUser(device)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID)
|
||||
cache.Set(cacheKey, eventID)
|
||||
}
|
||||
|
||||
func (c *LazyLoadCache) IsLazyLoadedUserCached(device *userapi.Device, roomID, userID string) (string, bool) {
|
||||
cache, err := c.lazyLoadCacheForUser(device)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID)
|
||||
val, ok := cache.Get(cacheKey)
|
||||
if !ok {
|
||||
return "", ok
|
||||
}
|
||||
return val.(string), ok
|
||||
}
|
||||
|
|
@ -3,8 +3,6 @@ package pushgateway
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// A Client is how interactions with a Push Gateway is done.
|
||||
|
|
@ -50,7 +48,7 @@ type Device struct {
|
|||
AppID string `json:"app_id"` // Required
|
||||
Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys.
|
||||
PushKey string `json:"pushkey"` // Required
|
||||
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
|
||||
PushKeyTS int64 `json:"pushkey_ts,omitempty"`
|
||||
Tweaks map[string]interface{} `json:"tweaks,omitempty"`
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ func AddPublicRoutes(
|
|||
userAPI userapi.UserInternalAPI,
|
||||
client *gomatrixserverlib.Client,
|
||||
) {
|
||||
mediaDB, err := storage.Open(&cfg.Database)
|
||||
mediaDB, err := storage.NewMediaAPIDatasource(&cfg.Database)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to media db")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
|
|
@ -311,6 +312,26 @@ func (r *uploadRequest) storeFileAndMetadata(
|
|||
}
|
||||
|
||||
go func() {
|
||||
file, err := os.Open(string(finalPath))
|
||||
if err != nil {
|
||||
r.Logger.WithError(err).Error("unable to open file")
|
||||
return
|
||||
}
|
||||
defer file.Close() // nolint: errcheck
|
||||
// http.DetectContentType only needs 512 bytes
|
||||
buf := make([]byte, 512)
|
||||
_, err = file.Read(buf)
|
||||
if err != nil {
|
||||
r.Logger.WithError(err).Error("unable to read file")
|
||||
return
|
||||
}
|
||||
// Check if we need to generate thumbnails
|
||||
fileType := http.DetectContentType(buf)
|
||||
if !strings.HasPrefix(fileType, "image") {
|
||||
r.Logger.WithField("contentType", fileType).Debugf("uploaded file is not an image or can not be thumbnailed, not generating thumbnails")
|
||||
return
|
||||
}
|
||||
|
||||
busy, err := thumbnailer.GenerateThumbnails(
|
||||
context.Background(), finalPath, thumbnailSizes, r.MediaMetadata,
|
||||
activeThumbnailGeneration, maxThumbnailGenerators, db, r.Logger,
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ func Test_uploadRequest_doUpload(t *testing.T) {
|
|||
_ = os.Mkdir(testdataPath, os.ModePerm)
|
||||
defer fileutils.RemoveDir(types.Path(testdataPath), nil)
|
||||
|
||||
db, err := storage.Open(&config.DatabaseOptions{
|
||||
db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{
|
||||
ConnectionString: "file::memory:?cache=shared",
|
||||
MaxOpenConnections: 100,
|
||||
MaxIdleConnections: 2,
|
||||
|
|
|
|||
|
|
@ -22,9 +22,17 @@ import (
|
|||
)
|
||||
|
||||
type Database interface {
|
||||
MediaRepository
|
||||
Thumbnails
|
||||
}
|
||||
|
||||
type MediaRepository interface {
|
||||
StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error
|
||||
GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
|
||||
GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
|
||||
}
|
||||
|
||||
type Thumbnails interface {
|
||||
StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error
|
||||
GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error)
|
||||
GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ import (
|
|||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
|
@ -69,24 +71,25 @@ type mediaStatements struct {
|
|||
selectMediaByHashStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *mediaStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(mediaSchema)
|
||||
func NewPostgresMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) {
|
||||
s := &mediaStatements{}
|
||||
_, err := db.Exec(mediaSchema)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return statementList{
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertMediaStmt, insertMediaSQL},
|
||||
{&s.selectMediaStmt, selectMediaSQL},
|
||||
{&s.selectMediaByHashStmt, selectMediaByHashSQL},
|
||||
}.prepare(db)
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *mediaStatements) insertMedia(
|
||||
ctx context.Context, mediaMetadata *types.MediaMetadata,
|
||||
func (s *mediaStatements) InsertMedia(
|
||||
ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata,
|
||||
) error {
|
||||
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
|
||||
_, err := s.insertMediaStmt.ExecContext(
|
||||
mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
|
||||
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext(
|
||||
ctx,
|
||||
mediaMetadata.MediaID,
|
||||
mediaMetadata.Origin,
|
||||
|
|
@ -100,14 +103,14 @@ func (s *mediaStatements) insertMedia(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *mediaStatements) selectMedia(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
func (s *mediaStatements) SelectMedia(
|
||||
ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata := types.MediaMetadata{
|
||||
MediaID: mediaID,
|
||||
Origin: mediaOrigin,
|
||||
}
|
||||
err := s.selectMediaStmt.QueryRowContext(
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext(
|
||||
ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
|
||||
).Scan(
|
||||
&mediaMetadata.ContentType,
|
||||
|
|
@ -120,14 +123,14 @@ func (s *mediaStatements) selectMedia(
|
|||
return &mediaMetadata, err
|
||||
}
|
||||
|
||||
func (s *mediaStatements) selectMediaByHash(
|
||||
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
func (s *mediaStatements) SelectMediaByHash(
|
||||
ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata := types.MediaMetadata{
|
||||
Base64Hash: mediaHash,
|
||||
Origin: mediaOrigin,
|
||||
}
|
||||
err := s.selectMediaStmt.QueryRowContext(
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext(
|
||||
ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
|
||||
).Scan(
|
||||
&mediaMetadata.ContentType,
|
||||
|
|
|
|||
46
mediaapi/storage/postgres/mediaapi.go
Normal file
46
mediaapi/storage/postgres/mediaapi.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
// Import the postgres database driver.
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
// NewDatabase opens a postgres database.
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mediaRepo, err := NewPostgresMediaRepositoryTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
thumbnails, err := NewPostgresThumbnailsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &shared.Database{
|
||||
MediaRepository: mediaRepo,
|
||||
Thumbnails: thumbnails,
|
||||
DB: db,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// FIXME: This should be made internal!
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
|
||||
type statementList []struct {
|
||||
statement **sql.Stmt
|
||||
sql string
|
||||
}
|
||||
|
||||
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
|
||||
func (s statementList) prepare(db *sql.DB) (err error) {
|
||||
for _, statement := range s {
|
||||
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type statements struct {
|
||||
media mediaStatements
|
||||
thumbnail thumbnailStatements
|
||||
}
|
||||
|
||||
func (s *statements) prepare(db *sql.DB) (err error) {
|
||||
if err = s.media.prepare(db); err != nil {
|
||||
return
|
||||
}
|
||||
if err = s.thumbnail.prepare(db); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
@ -21,6 +21,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
|
@ -63,7 +65,7 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE
|
|||
|
||||
// Note: this selects all thumbnails for a media_origin and media_id
|
||||
const selectThumbnailsSQL = `
|
||||
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2
|
||||
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC
|
||||
`
|
||||
|
||||
type thumbnailStatements struct {
|
||||
|
|
@ -72,24 +74,25 @@ type thumbnailStatements struct {
|
|||
selectThumbnailsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(thumbnailSchema)
|
||||
func NewPostgresThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) {
|
||||
s := &thumbnailStatements{}
|
||||
_, err := db.Exec(thumbnailSchema)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return statementList{
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertThumbnailStmt, insertThumbnailSQL},
|
||||
{&s.selectThumbnailStmt, selectThumbnailSQL},
|
||||
{&s.selectThumbnailsStmt, selectThumbnailsSQL},
|
||||
}.prepare(db)
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) insertThumbnail(
|
||||
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
|
||||
func (s *thumbnailStatements) InsertThumbnail(
|
||||
ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata,
|
||||
) error {
|
||||
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
|
||||
_, err := s.insertThumbnailStmt.ExecContext(
|
||||
thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
|
||||
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext(
|
||||
ctx,
|
||||
thumbnailMetadata.MediaMetadata.MediaID,
|
||||
thumbnailMetadata.MediaMetadata.Origin,
|
||||
|
|
@ -103,8 +106,9 @@ func (s *thumbnailStatements) insertThumbnail(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) selectThumbnail(
|
||||
func (s *thumbnailStatements) SelectThumbnail(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
mediaID types.MediaID,
|
||||
mediaOrigin gomatrixserverlib.ServerName,
|
||||
width, height int,
|
||||
|
|
@ -121,7 +125,7 @@ func (s *thumbnailStatements) selectThumbnail(
|
|||
ResizeMethod: resizeMethod,
|
||||
},
|
||||
}
|
||||
err := s.selectThumbnailStmt.QueryRowContext(
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext(
|
||||
ctx,
|
||||
thumbnailMetadata.MediaMetadata.MediaID,
|
||||
thumbnailMetadata.MediaMetadata.Origin,
|
||||
|
|
@ -136,10 +140,10 @@ func (s *thumbnailStatements) selectThumbnail(
|
|||
return &thumbnailMetadata, err
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) selectThumbnails(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
func (s *thumbnailStatements) SelectThumbnails(
|
||||
ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) ([]*types.ThumbnailMetadata, error) {
|
||||
rows, err := s.selectThumbnailsStmt.QueryContext(
|
||||
rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext(
|
||||
ctx, mediaID, mediaOrigin,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
|
|
@ -13,54 +12,38 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
// Import the postgres database driver.
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// Database is used to store metadata about a repository of media files.
|
||||
type Database struct {
|
||||
statements statements
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// Open opens a postgres database.
|
||||
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||
var d Database
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.statements.prepare(d.db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &d, nil
|
||||
DB *sql.DB
|
||||
Writer sqlutil.Writer
|
||||
MediaRepository tables.MediaRepository
|
||||
Thumbnails tables.Thumbnails
|
||||
}
|
||||
|
||||
// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
|
||||
// Returns an error if the combination of MediaID and Origin are not unique in the table.
|
||||
func (d *Database) StoreMediaMetadata(
|
||||
ctx context.Context, mediaMetadata *types.MediaMetadata,
|
||||
) error {
|
||||
return d.statements.media.insertMedia(ctx, mediaMetadata)
|
||||
func (d Database) StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.MediaRepository.InsertMedia(ctx, txn, mediaMetadata)
|
||||
})
|
||||
}
|
||||
|
||||
// GetMediaMetadata returns metadata about media stored on this server.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there is no metadata associated with this media.
|
||||
func (d *Database) GetMediaMetadata(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin)
|
||||
func (d Database) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
|
||||
mediaMetadata, err := d.MediaRepository.SelectMedia(ctx, nil, mediaID, mediaOrigin)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
@ -70,10 +53,8 @@ func (d *Database) GetMediaMetadata(
|
|||
// GetMediaMetadataByHash returns metadata about media stored on this server.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there is no metadata associated with this media.
|
||||
func (d *Database) GetMediaMetadataByHash(
|
||||
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin)
|
||||
func (d Database) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
|
||||
mediaMetadata, err := d.MediaRepository.SelectMediaByHash(ctx, nil, mediaHash, mediaOrigin)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
@ -82,40 +63,36 @@ func (d *Database) GetMediaMetadataByHash(
|
|||
|
||||
// StoreThumbnail inserts the metadata about the thumbnail into the database.
|
||||
// Returns an error if the combination of MediaID and Origin are not unique in the table.
|
||||
func (d *Database) StoreThumbnail(
|
||||
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
|
||||
) error {
|
||||
return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata)
|
||||
func (d Database) StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Thumbnails.InsertThumbnail(ctx, txn, thumbnailMetadata)
|
||||
})
|
||||
}
|
||||
|
||||
// GetThumbnail returns metadata about a specific thumbnail.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there is no metadata associated with this thumbnail.
|
||||
func (d *Database) GetThumbnail(
|
||||
ctx context.Context,
|
||||
mediaID types.MediaID,
|
||||
mediaOrigin gomatrixserverlib.ServerName,
|
||||
width, height int,
|
||||
resizeMethod string,
|
||||
) (*types.ThumbnailMetadata, error) {
|
||||
thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail(
|
||||
ctx, mediaID, mediaOrigin, width, height, resizeMethod,
|
||||
)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
func (d Database) GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) {
|
||||
metadata, err := d.Thumbnails.SelectThumbnail(ctx, nil, mediaID, mediaOrigin, width, height, resizeMethod)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return thumbnailMetadata, err
|
||||
return nil, err
|
||||
}
|
||||
return metadata, err
|
||||
}
|
||||
|
||||
// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there are no thumbnails associated with this media.
|
||||
func (d *Database) GetThumbnails(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) ([]*types.ThumbnailMetadata, error) {
|
||||
thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
func (d Database) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) {
|
||||
metadatas, err := d.Thumbnails.SelectThumbnails(ctx, nil, mediaID, mediaOrigin)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return thumbnails, err
|
||||
return nil, err
|
||||
}
|
||||
return metadatas, err
|
||||
}
|
||||
|
|
@ -21,6 +21,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
|
@ -66,35 +67,32 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, media_id, user_i
|
|||
|
||||
type mediaStatements struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
insertMediaStmt *sql.Stmt
|
||||
selectMediaStmt *sql.Stmt
|
||||
selectMediaByHashStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||
s.db = db
|
||||
s.writer = writer
|
||||
|
||||
_, err = db.Exec(mediaSchema)
|
||||
func NewSQLiteMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) {
|
||||
s := &mediaStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(mediaSchema)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return statementList{
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertMediaStmt, insertMediaSQL},
|
||||
{&s.selectMediaStmt, selectMediaSQL},
|
||||
{&s.selectMediaByHashStmt, selectMediaByHashSQL},
|
||||
}.prepare(db)
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *mediaStatements) insertMedia(
|
||||
ctx context.Context, mediaMetadata *types.MediaMetadata,
|
||||
func (s *mediaStatements) InsertMedia(
|
||||
ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata,
|
||||
) error {
|
||||
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertMediaStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
|
||||
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext(
|
||||
ctx,
|
||||
mediaMetadata.MediaID,
|
||||
mediaMetadata.Origin,
|
||||
|
|
@ -106,17 +104,16 @@ func (s *mediaStatements) insertMedia(
|
|||
mediaMetadata.UserID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *mediaStatements) selectMedia(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
func (s *mediaStatements) SelectMedia(
|
||||
ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata := types.MediaMetadata{
|
||||
MediaID: mediaID,
|
||||
Origin: mediaOrigin,
|
||||
}
|
||||
err := s.selectMediaStmt.QueryRowContext(
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext(
|
||||
ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
|
||||
).Scan(
|
||||
&mediaMetadata.ContentType,
|
||||
|
|
@ -129,14 +126,14 @@ func (s *mediaStatements) selectMedia(
|
|||
return &mediaMetadata, err
|
||||
}
|
||||
|
||||
func (s *mediaStatements) selectMediaByHash(
|
||||
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
func (s *mediaStatements) SelectMediaByHash(
|
||||
ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata := types.MediaMetadata{
|
||||
Base64Hash: mediaHash,
|
||||
Origin: mediaOrigin,
|
||||
}
|
||||
err := s.selectMediaStmt.QueryRowContext(
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext(
|
||||
ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
|
||||
).Scan(
|
||||
&mediaMetadata.ContentType,
|
||||
|
|
|
|||
|
|
@ -16,23 +16,30 @@
|
|||
package sqlite3
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
// Import the postgres database driver.
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
type statements struct {
|
||||
media mediaStatements
|
||||
thumbnail thumbnailStatements
|
||||
}
|
||||
|
||||
func (s *statements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||
if err = s.media.prepare(db, writer); err != nil {
|
||||
return
|
||||
// NewDatabase opens a SQLIte database.
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = s.thumbnail.prepare(db, writer); err != nil {
|
||||
return
|
||||
mediaRepo, err := NewSQLiteMediaRepositoryTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
thumbnails, err := NewSQLiteThumbnailsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &shared.Database{
|
||||
MediaRepository: mediaRepo,
|
||||
Thumbnails: thumbnails,
|
||||
DB: db,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// FIXME: This should be made internal!
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
|
||||
type statementList []struct {
|
||||
statement **sql.Stmt
|
||||
sql string
|
||||
}
|
||||
|
||||
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
|
||||
func (s statementList) prepare(db *sql.DB) (err error) {
|
||||
for _, statement := range s {
|
||||
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -1,123 +0,0 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
// Import the postgres database driver.
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// Database is used to store metadata about a repository of media files.
|
||||
type Database struct {
|
||||
statements statements
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
}
|
||||
|
||||
// Open opens a postgres database.
|
||||
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||
d := Database{
|
||||
writer: sqlutil.NewExclusiveWriter(),
|
||||
}
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.statements.prepare(d.db, d.writer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
|
||||
// Returns an error if the combination of MediaID and Origin are not unique in the table.
|
||||
func (d *Database) StoreMediaMetadata(
|
||||
ctx context.Context, mediaMetadata *types.MediaMetadata,
|
||||
) error {
|
||||
return d.statements.media.insertMedia(ctx, mediaMetadata)
|
||||
}
|
||||
|
||||
// GetMediaMetadata returns metadata about media stored on this server.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there is no metadata associated with this media.
|
||||
func (d *Database) GetMediaMetadata(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return mediaMetadata, err
|
||||
}
|
||||
|
||||
// GetMediaMetadataByHash returns metadata about media stored on this server.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there is no metadata associated with this media.
|
||||
func (d *Database) GetMediaMetadataByHash(
|
||||
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error) {
|
||||
mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return mediaMetadata, err
|
||||
}
|
||||
|
||||
// StoreThumbnail inserts the metadata about the thumbnail into the database.
|
||||
// Returns an error if the combination of MediaID and Origin are not unique in the table.
|
||||
func (d *Database) StoreThumbnail(
|
||||
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
|
||||
) error {
|
||||
return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata)
|
||||
}
|
||||
|
||||
// GetThumbnail returns metadata about a specific thumbnail.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there is no metadata associated with this thumbnail.
|
||||
func (d *Database) GetThumbnail(
|
||||
ctx context.Context,
|
||||
mediaID types.MediaID,
|
||||
mediaOrigin gomatrixserverlib.ServerName,
|
||||
width, height int,
|
||||
resizeMethod string,
|
||||
) (*types.ThumbnailMetadata, error) {
|
||||
thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail(
|
||||
ctx, mediaID, mediaOrigin, width, height, resizeMethod,
|
||||
)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return thumbnailMetadata, err
|
||||
}
|
||||
|
||||
// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server.
|
||||
// The media could have been uploaded to this server or fetched from another server and cached here.
|
||||
// Returns nil metadata if there are no thumbnails associated with this media.
|
||||
func (d *Database) GetThumbnails(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) ([]*types.ThumbnailMetadata, error) {
|
||||
thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return thumbnails, err
|
||||
}
|
||||
|
|
@ -22,6 +22,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
|
@ -54,39 +55,32 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE
|
|||
|
||||
// Note: this selects all thumbnails for a media_origin and media_id
|
||||
const selectThumbnailsSQL = `
|
||||
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2
|
||||
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC
|
||||
`
|
||||
|
||||
type thumbnailStatements struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
insertThumbnailStmt *sql.Stmt
|
||||
selectThumbnailStmt *sql.Stmt
|
||||
selectThumbnailsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||
_, err = db.Exec(thumbnailSchema)
|
||||
func NewSQLiteThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) {
|
||||
s := &thumbnailStatements{}
|
||||
_, err := db.Exec(thumbnailSchema)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
s.db = db
|
||||
s.writer = writer
|
||||
|
||||
return statementList{
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertThumbnailStmt, insertThumbnailSQL},
|
||||
{&s.selectThumbnailStmt, selectThumbnailSQL},
|
||||
{&s.selectThumbnailsStmt, selectThumbnailsSQL},
|
||||
}.prepare(db)
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) insertThumbnail(
|
||||
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
|
||||
) error {
|
||||
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
func (s *thumbnailStatements) InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error {
|
||||
thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
|
||||
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext(
|
||||
ctx,
|
||||
thumbnailMetadata.MediaMetadata.MediaID,
|
||||
thumbnailMetadata.MediaMetadata.Origin,
|
||||
|
|
@ -98,11 +92,11 @@ func (s *thumbnailStatements) insertThumbnail(
|
|||
thumbnailMetadata.ThumbnailSize.ResizeMethod,
|
||||
)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) selectThumbnail(
|
||||
func (s *thumbnailStatements) SelectThumbnail(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
mediaID types.MediaID,
|
||||
mediaOrigin gomatrixserverlib.ServerName,
|
||||
width, height int,
|
||||
|
|
@ -119,7 +113,7 @@ func (s *thumbnailStatements) selectThumbnail(
|
|||
ResizeMethod: resizeMethod,
|
||||
},
|
||||
}
|
||||
err := s.selectThumbnailStmt.QueryRowContext(
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext(
|
||||
ctx,
|
||||
thumbnailMetadata.MediaMetadata.MediaID,
|
||||
thumbnailMetadata.MediaMetadata.Origin,
|
||||
|
|
@ -134,10 +128,11 @@ func (s *thumbnailStatements) selectThumbnail(
|
|||
return &thumbnailMetadata, err
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) selectThumbnails(
|
||||
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
func (s *thumbnailStatements) SelectThumbnails(
|
||||
ctx context.Context, txn *sql.Tx, mediaID types.MediaID,
|
||||
mediaOrigin gomatrixserverlib.ServerName,
|
||||
) ([]*types.ThumbnailMetadata, error) {
|
||||
rows, err := s.selectThumbnailsStmt.QueryContext(
|
||||
rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext(
|
||||
ctx, mediaID, mediaOrigin,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -25,13 +25,13 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
// Open opens a postgres database.
|
||||
func Open(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||
// NewMediaAPIDatasource opens a database connection.
|
||||
func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.Open(dbProperties)
|
||||
return sqlite3.NewDatabase(dbProperties)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return postgres.Open(dbProperties)
|
||||
return postgres.NewDatabase(dbProperties)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
|
|
|
|||
135
mediaapi/storage/storage_test.go
Normal file
135
mediaapi/storage/storage_test.go
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
package storage_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewSyncServerDatasource returned %s", err)
|
||||
}
|
||||
return db, close
|
||||
}
|
||||
func TestMediaRepository(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
ctx := context.Background()
|
||||
t.Run("can insert media & query media", func(t *testing.T) {
|
||||
metadata := &types.MediaMetadata{
|
||||
MediaID: "testing",
|
||||
Origin: "localhost",
|
||||
ContentType: "image/png",
|
||||
FileSizeBytes: 10,
|
||||
UploadName: "upload test",
|
||||
Base64Hash: "dGVzdGluZw==",
|
||||
UserID: "@alice:localhost",
|
||||
}
|
||||
if err := db.StoreMediaMetadata(ctx, metadata); err != nil {
|
||||
t.Fatalf("unable to store media metadata: %v", err)
|
||||
}
|
||||
// query by media id
|
||||
gotMetadata, err := db.GetMediaMetadata(ctx, metadata.MediaID, metadata.Origin)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to query media metadata: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(metadata, gotMetadata) {
|
||||
t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata)
|
||||
}
|
||||
// query by media hash
|
||||
gotMetadata, err = db.GetMediaMetadataByHash(ctx, metadata.Base64Hash, metadata.Origin)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to query media metadata by hash: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(metadata, gotMetadata) {
|
||||
t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestThumbnailsStorage(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
ctx := context.Background()
|
||||
t.Run("can insert thumbnails & query media", func(t *testing.T) {
|
||||
thumbnails := []*types.ThumbnailMetadata{
|
||||
{
|
||||
MediaMetadata: &types.MediaMetadata{
|
||||
MediaID: "testing",
|
||||
Origin: "localhost",
|
||||
ContentType: "image/png",
|
||||
FileSizeBytes: 6,
|
||||
},
|
||||
ThumbnailSize: types.ThumbnailSize{
|
||||
Width: 5,
|
||||
Height: 5,
|
||||
ResizeMethod: types.Crop,
|
||||
},
|
||||
},
|
||||
{
|
||||
MediaMetadata: &types.MediaMetadata{
|
||||
MediaID: "testing",
|
||||
Origin: "localhost",
|
||||
ContentType: "image/png",
|
||||
FileSizeBytes: 7,
|
||||
},
|
||||
ThumbnailSize: types.ThumbnailSize{
|
||||
Width: 1,
|
||||
Height: 1,
|
||||
ResizeMethod: types.Scale,
|
||||
},
|
||||
},
|
||||
}
|
||||
for i := range thumbnails {
|
||||
if err := db.StoreThumbnail(ctx, thumbnails[i]); err != nil {
|
||||
t.Fatalf("unable to store thumbnail metadata: %v", err)
|
||||
}
|
||||
}
|
||||
// query by single thumbnail
|
||||
gotMetadata, err := db.GetThumbnail(ctx,
|
||||
thumbnails[0].MediaMetadata.MediaID,
|
||||
thumbnails[0].MediaMetadata.Origin,
|
||||
thumbnails[0].ThumbnailSize.Width, thumbnails[0].ThumbnailSize.Height,
|
||||
thumbnails[0].ThumbnailSize.ResizeMethod,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to query thumbnail metadata: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata) {
|
||||
t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata)
|
||||
}
|
||||
if !reflect.DeepEqual(thumbnails[0].ThumbnailSize, gotMetadata.ThumbnailSize) {
|
||||
t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata)
|
||||
}
|
||||
// query by all thumbnails
|
||||
gotMediadatas, err := db.GetThumbnails(ctx, thumbnails[0].MediaMetadata.MediaID, thumbnails[0].MediaMetadata.Origin)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to query media metadata by hash: %v", err)
|
||||
}
|
||||
if len(gotMediadatas) != len(thumbnails) {
|
||||
t.Fatalf("expected %d stored thumbnail metadata, got %d", len(thumbnails), len(gotMediadatas))
|
||||
}
|
||||
for i := range gotMediadatas {
|
||||
if !reflect.DeepEqual(thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata) {
|
||||
t.Fatalf("expected metadata %+v, got %v", thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata)
|
||||
}
|
||||
if !reflect.DeepEqual(thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize) {
|
||||
t.Fatalf("expected metadata %+v, got %v", thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
@ -22,10 +22,10 @@ import (
|
|||
)
|
||||
|
||||
// Open opens a postgres database.
|
||||
func Open(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||
func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.Open(dbProperties)
|
||||
return sqlite3.NewDatabase(dbProperties)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return nil, fmt.Errorf("can't use Postgres implementation")
|
||||
default:
|
||||
|
|
|
|||
46
mediaapi/storage/tables/interface.go
Normal file
46
mediaapi/storage/tables/interface.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type Thumbnails interface {
|
||||
InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error
|
||||
SelectThumbnail(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
|
||||
width, height int,
|
||||
resizeMethod string,
|
||||
) (*types.ThumbnailMetadata, error)
|
||||
SelectThumbnails(
|
||||
ctx context.Context, txn *sql.Tx, mediaID types.MediaID,
|
||||
mediaOrigin gomatrixserverlib.ServerName,
|
||||
) ([]*types.ThumbnailMetadata, error)
|
||||
}
|
||||
|
||||
type MediaRepository interface {
|
||||
InsertMedia(ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata) error
|
||||
SelectMedia(ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
|
||||
SelectMediaByHash(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
|
||||
) (*types.MediaMetadata, error)
|
||||
}
|
||||
|
|
@ -45,16 +45,13 @@ type RequestMethod string
|
|||
// MatrixUserID is a Matrix user ID string in the form @user:domain e.g. @alice:matrix.org
|
||||
type MatrixUserID string
|
||||
|
||||
// UnixMs is the milliseconds since the Unix epoch
|
||||
type UnixMs int64
|
||||
|
||||
// MediaMetadata is metadata associated with a media file
|
||||
type MediaMetadata struct {
|
||||
MediaID MediaID
|
||||
Origin gomatrixserverlib.ServerName
|
||||
ContentType ContentType
|
||||
FileSizeBytes FileSizeBytes
|
||||
CreationTimestamp UnixMs
|
||||
CreationTimestamp gomatrixserverlib.Timestamp
|
||||
UploadName Filename
|
||||
Base64Hash Base64Hash
|
||||
UserID MatrixUserID
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ import (
|
|||
type Notifier struct {
|
||||
lock *sync.RWMutex
|
||||
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
|
||||
roomIDToJoinedUsers map[string]userIDSet
|
||||
roomIDToJoinedUsers map[string]*userIDSet
|
||||
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
|
||||
roomIDToPeekingDevices map[string]peekingDeviceSet
|
||||
// The latest sync position
|
||||
|
|
@ -54,7 +54,7 @@ type Notifier struct {
|
|||
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{
|
||||
roomIDToJoinedUsers: make(map[string]userIDSet),
|
||||
roomIDToJoinedUsers: make(map[string]*userIDSet),
|
||||
roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
|
||||
userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
|
||||
lock: &sync.RWMutex{},
|
||||
|
|
@ -262,7 +262,7 @@ func (n *Notifier) SharedUsers(userID string) []string {
|
|||
func (n *Notifier) _sharedUsers(userID string) []string {
|
||||
n._sharedUserMap[userID] = struct{}{}
|
||||
for roomID, users := range n.roomIDToJoinedUsers {
|
||||
if _, ok := users[userID]; !ok {
|
||||
if ok := users.isIn(userID); !ok {
|
||||
continue
|
||||
}
|
||||
for _, userID := range n._joinedUsers(roomID) {
|
||||
|
|
@ -282,8 +282,11 @@ func (n *Notifier) IsSharedUser(userA, userB string) bool {
|
|||
defer n.lock.RUnlock()
|
||||
var okA, okB bool
|
||||
for _, users := range n.roomIDToJoinedUsers {
|
||||
_, okA = users[userA]
|
||||
_, okB = users[userB]
|
||||
okA = users.isIn(userA)
|
||||
if !okA {
|
||||
continue
|
||||
}
|
||||
okB = users.isIn(userB)
|
||||
if okA && okB {
|
||||
return true
|
||||
}
|
||||
|
|
@ -345,11 +348,12 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
|
|||
// This is just the bulk form of addJoinedUser
|
||||
for roomID, userIDs := range roomIDToUserIDs {
|
||||
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
|
||||
n.roomIDToJoinedUsers[roomID] = make(userIDSet, len(userIDs))
|
||||
n.roomIDToJoinedUsers[roomID] = newUserIDSet(len(userIDs))
|
||||
}
|
||||
for _, userID := range userIDs {
|
||||
n.roomIDToJoinedUsers[roomID].add(userID)
|
||||
}
|
||||
n.roomIDToJoinedUsers[roomID].precompute()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -440,16 +444,18 @@ func (n *Notifier) _fetchUserStreams(userID string) []*UserDeviceStream {
|
|||
|
||||
func (n *Notifier) _addJoinedUser(roomID, userID string) {
|
||||
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
|
||||
n.roomIDToJoinedUsers[roomID] = make(userIDSet)
|
||||
n.roomIDToJoinedUsers[roomID] = newUserIDSet(8)
|
||||
}
|
||||
n.roomIDToJoinedUsers[roomID].add(userID)
|
||||
n.roomIDToJoinedUsers[roomID].precompute()
|
||||
}
|
||||
|
||||
func (n *Notifier) _removeJoinedUser(roomID, userID string) {
|
||||
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
|
||||
n.roomIDToJoinedUsers[roomID] = make(userIDSet)
|
||||
n.roomIDToJoinedUsers[roomID] = newUserIDSet(8)
|
||||
}
|
||||
n.roomIDToJoinedUsers[roomID].remove(userID)
|
||||
n.roomIDToJoinedUsers[roomID].precompute()
|
||||
}
|
||||
|
||||
func (n *Notifier) JoinedUsers(roomID string) (userIDs []string) {
|
||||
|
|
@ -521,19 +527,52 @@ func (n *Notifier) _removeEmptyUserStreams() {
|
|||
}
|
||||
|
||||
// A string set, mainly existing for improving clarity of structs in this file.
|
||||
type userIDSet map[string]struct{}
|
||||
|
||||
func (s userIDSet) add(str string) {
|
||||
s[str] = struct{}{}
|
||||
type userIDSet struct {
|
||||
sync.Mutex
|
||||
set map[string]struct{}
|
||||
precomputed []string
|
||||
}
|
||||
|
||||
func (s userIDSet) remove(str string) {
|
||||
delete(s, str)
|
||||
func newUserIDSet(cap int) *userIDSet {
|
||||
return &userIDSet{
|
||||
set: make(map[string]struct{}, cap),
|
||||
precomputed: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func (s userIDSet) values() (vals []string) {
|
||||
vals = make([]string, 0, len(s))
|
||||
for str := range s {
|
||||
func (s *userIDSet) add(str string) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.set[str] = struct{}{}
|
||||
s.precomputed = s.precomputed[:0] // invalidate cache
|
||||
}
|
||||
|
||||
func (s *userIDSet) remove(str string) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
delete(s.set, str)
|
||||
s.precomputed = s.precomputed[:0] // invalidate cache
|
||||
}
|
||||
|
||||
func (s *userIDSet) precompute() {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.precomputed = s.values()
|
||||
}
|
||||
|
||||
func (s *userIDSet) isIn(str string) bool {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
_, ok := s.set[str]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *userIDSet) values() (vals []string) {
|
||||
if len(s.precomputed) > 0 {
|
||||
return s.precomputed // only return if not invalidated
|
||||
}
|
||||
vals = make([]string, 0, len(s.set))
|
||||
for str := range s.set {
|
||||
vals = append(vals, str)
|
||||
}
|
||||
return
|
||||
|
|
|
|||
|
|
@ -60,7 +60,9 @@ func Context(
|
|||
Headers: nil,
|
||||
}
|
||||
}
|
||||
filter.Rooms = append(filter.Rooms, roomID)
|
||||
if filter.Rooms != nil {
|
||||
*filter.Rooms = append(*filter.Rooms, roomID)
|
||||
}
|
||||
|
||||
ctx := req.Context()
|
||||
membershipRes := roomserver.QueryMembershipForUserResponse{}
|
||||
|
|
|
|||
|
|
@ -262,12 +262,8 @@ func (r *messagesReq) retrieveEvents() (
|
|||
clientEvents []gomatrixserverlib.ClientEvent, start,
|
||||
end types.TopologyToken, err error,
|
||||
) {
|
||||
eventFilter := r.filter
|
||||
|
||||
// Retrieve the events from the local database.
|
||||
streamEvents, err := r.db.GetEventsInTopologicalRange(
|
||||
r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
|
||||
)
|
||||
streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("GetEventsInRange: %w", err)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -104,8 +104,8 @@ type Database interface {
|
|||
// DeletePeek deletes all peeks for a given room by a given user
|
||||
// Returns an error if there was a problem communicating with the database.
|
||||
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
|
||||
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
|
||||
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
|
||||
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
|
||||
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
|
||||
// EventPositionInTopology returns the depth and stream position of the given event.
|
||||
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
|
||||
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
|
||||
|
|
|
|||
|
|
@ -47,14 +47,10 @@ const selectBackwardExtremitiesForRoomSQL = "" +
|
|||
const deleteBackwardExtremitySQL = "" +
|
||||
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
|
||||
|
||||
const deleteBackwardExtremitiesForRoomSQL = "" +
|
||||
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
|
||||
|
||||
type backwardExtremitiesStatements struct {
|
||||
insertBackwardExtremityStmt *sql.Stmt
|
||||
selectBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||
deleteBackwardExtremityStmt *sql.Stmt
|
||||
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
|
||||
|
|
@ -72,9 +68,6 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti
|
|||
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
|
@ -113,10 +106,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
|||
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -233,9 +233,10 @@ func (s *currentRoomStateStatements) SelectCurrentState(
|
|||
excludeEventIDs []string,
|
||||
) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt)
|
||||
senders, notSenders := getSendersStateFilterFilter(stateFilter)
|
||||
rows, err := stmt.QueryContext(ctx, roomID,
|
||||
pq.StringArray(stateFilter.Senders),
|
||||
pq.StringArray(stateFilter.NotSenders),
|
||||
pq.StringArray(senders),
|
||||
pq.StringArray(notSenders),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
|
||||
stateFilter.ContainsURL,
|
||||
|
|
|
|||
|
|
@ -16,21 +16,45 @@ package postgres
|
|||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// filterConvertWildcardToSQL converts wildcards as defined in
|
||||
// https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter
|
||||
// to SQL wildcards that can be used with LIKE()
|
||||
func filterConvertTypeWildcardToSQL(values []string) []string {
|
||||
func filterConvertTypeWildcardToSQL(values *[]string) []string {
|
||||
if values == nil {
|
||||
// Return nil instead of []string{} so IS NULL can work correctly when
|
||||
// the return value is passed into SQL queries
|
||||
return nil
|
||||
}
|
||||
|
||||
ret := make([]string, len(values))
|
||||
for i := range values {
|
||||
ret[i] = strings.Replace(values[i], "*", "%", -1)
|
||||
v := *values
|
||||
ret := make([]string, len(v))
|
||||
for i := range v {
|
||||
ret[i] = strings.Replace(v[i], "*", "%", -1)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// TODO: Replace when Dendrite uses Go 1.18
|
||||
func getSendersRoomEventFilter(filter *gomatrixserverlib.RoomEventFilter) (senders []string, notSenders []string) {
|
||||
if filter.Senders != nil {
|
||||
senders = *filter.Senders
|
||||
}
|
||||
if filter.NotSenders != nil {
|
||||
notSenders = *filter.NotSenders
|
||||
}
|
||||
return senders, notSenders
|
||||
}
|
||||
|
||||
func getSendersStateFilterFilter(filter *gomatrixserverlib.StateFilter) (senders []string, notSenders []string) {
|
||||
if filter.Senders != nil {
|
||||
senders = *filter.Senders
|
||||
}
|
||||
if filter.NotSenders != nil {
|
||||
notSenders = *filter.NotSenders
|
||||
}
|
||||
return senders, notSenders
|
||||
}
|
||||
|
|
|
|||
|
|
@ -56,12 +56,6 @@ const upsertMembershipSQL = "" +
|
|||
" ON CONFLICT ON CONSTRAINT syncapi_memberships_unique" +
|
||||
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
|
||||
|
||||
const selectMembershipSQL = "" +
|
||||
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
|
||||
" WHERE room_id = $1 AND user_id = $2 AND membership = ANY($3)" +
|
||||
" ORDER BY stream_pos DESC" +
|
||||
" LIMIT 1"
|
||||
|
||||
const selectMembershipCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM (" +
|
||||
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
|
||||
|
|
@ -69,7 +63,6 @@ const selectMembershipCountSQL = "" +
|
|||
|
||||
type membershipsStatements struct {
|
||||
upsertMembershipStmt *sql.Stmt
|
||||
selectMembershipStmt *sql.Stmt
|
||||
selectMembershipCountStmt *sql.Stmt
|
||||
}
|
||||
|
||||
|
|
@ -82,9 +75,6 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
|
|||
if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -111,14 +101,6 @@ func (s *membershipsStatements) UpsertMembership(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *membershipsStatements) SelectMembership(
|
||||
ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
|
||||
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt)
|
||||
err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipsStatements) SelectMembershipCount(
|
||||
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
|
||||
) (count int, err error) {
|
||||
|
|
|
|||
|
|
@ -81,6 +81,15 @@ const insertEventSQL = "" +
|
|||
const selectEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)"
|
||||
|
||||
const selectEventsWithFilterSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" +
|
||||
" AND ( $2::text[] IS NULL OR sender = ANY($2) )" +
|
||||
" AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" +
|
||||
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
|
||||
" AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" +
|
||||
" AND ( $6::bool IS NULL OR contains_url = $6 )" +
|
||||
" LIMIT $7"
|
||||
|
||||
const selectRecentEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
|
||||
|
|
@ -153,6 +162,7 @@ const selectContextAfterEventSQL = "" +
|
|||
type outputRoomEventsStatements struct {
|
||||
insertEventStmt *sql.Stmt
|
||||
selectEventsStmt *sql.Stmt
|
||||
selectEventsWitFilterStmt *sql.Stmt
|
||||
selectMaxEventIDStmt *sql.Stmt
|
||||
selectRecentEventsStmt *sql.Stmt
|
||||
selectRecentEventsForSyncStmt *sql.Stmt
|
||||
|
|
@ -174,6 +184,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
|
|||
return s, sqlutil.StatementList{
|
||||
{&s.insertEventStmt, insertEventSQL},
|
||||
{&s.selectEventsStmt, selectEventsSQL},
|
||||
{&s.selectEventsWitFilterStmt, selectEventsWithFilterSQL},
|
||||
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
|
||||
{&s.selectRecentEventsStmt, selectRecentEventsSQL},
|
||||
{&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL},
|
||||
|
|
@ -204,11 +215,11 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
|
|||
stateFilter *gomatrixserverlib.StateFilter, roomIDs []string,
|
||||
) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)
|
||||
|
||||
senders, notSenders := getSendersStateFilterFilter(stateFilter)
|
||||
rows, err := stmt.QueryContext(
|
||||
ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
|
||||
pq.StringArray(stateFilter.Senders),
|
||||
pq.StringArray(stateFilter.NotSenders),
|
||||
pq.StringArray(senders),
|
||||
pq.StringArray(notSenders),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
|
||||
stateFilter.ContainsURL,
|
||||
|
|
@ -310,7 +321,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
|
|||
// Parse content as JSON and search for an "url" key
|
||||
containsURL := false
|
||||
var content map[string]interface{}
|
||||
if json.Unmarshal(event.Content(), &content) != nil {
|
||||
if json.Unmarshal(event.Content(), &content) == nil {
|
||||
// Set containsURL to true if url is present
|
||||
_, containsURL = content["url"]
|
||||
}
|
||||
|
|
@ -353,10 +364,11 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
|
|||
} else {
|
||||
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
|
||||
}
|
||||
senders, notSenders := getSendersRoomEventFilter(eventFilter)
|
||||
rows, err := stmt.QueryContext(
|
||||
ctx, roomID, r.Low(), r.High(),
|
||||
pq.StringArray(eventFilter.Senders),
|
||||
pq.StringArray(eventFilter.NotSenders),
|
||||
pq.StringArray(senders),
|
||||
pq.StringArray(notSenders),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
|
||||
eventFilter.Limit+1,
|
||||
|
|
@ -398,11 +410,12 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
|||
ctx context.Context, txn *sql.Tx,
|
||||
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
|
||||
) ([]types.StreamEvent, error) {
|
||||
senders, notSenders := getSendersRoomEventFilter(eventFilter)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt)
|
||||
rows, err := stmt.QueryContext(
|
||||
ctx, roomID, r.Low(), r.High(),
|
||||
pq.StringArray(eventFilter.Senders),
|
||||
pq.StringArray(eventFilter.NotSenders),
|
||||
pq.StringArray(senders),
|
||||
pq.StringArray(notSenders),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
|
||||
eventFilter.Limit,
|
||||
|
|
@ -427,15 +440,52 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
|||
// selectEvents returns the events for the given event IDs. If an event is
|
||||
// missing from the database, it will be omitted.
|
||||
func (s *outputRoomEventsStatements) SelectEvents(
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
|
||||
) ([]types.StreamEvent, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||
var (
|
||||
stmt *sql.Stmt
|
||||
rows *sql.Rows
|
||||
err error
|
||||
)
|
||||
if filter == nil {
|
||||
stmt = sqlutil.TxStmt(txn, s.selectEventsStmt)
|
||||
rows, err = stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||
} else {
|
||||
senders, notSenders := getSendersRoomEventFilter(filter)
|
||||
stmt = sqlutil.TxStmt(txn, s.selectEventsWitFilterStmt)
|
||||
rows, err = stmt.QueryContext(ctx,
|
||||
pq.StringArray(eventIDs),
|
||||
pq.StringArray(senders),
|
||||
pq.StringArray(notSenders),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
|
||||
filter.ContainsURL,
|
||||
filter.Limit,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
|
||||
return rowsToStreamEvents(rows)
|
||||
streamEvents, err := rowsToStreamEvents(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if preserveOrder {
|
||||
eventMap := make(map[string]types.StreamEvent)
|
||||
for _, ev := range streamEvents {
|
||||
eventMap[ev.EventID()] = ev
|
||||
}
|
||||
var returnEvents []types.StreamEvent
|
||||
for _, eventID := range eventIDs {
|
||||
ev, ok := eventMap[eventID]
|
||||
if ok {
|
||||
returnEvents = append(returnEvents, ev)
|
||||
}
|
||||
}
|
||||
return returnEvents, nil
|
||||
}
|
||||
return streamEvents, nil
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
|
||||
|
|
@ -462,10 +512,11 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn
|
|||
func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
|
||||
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
|
||||
) (evts []*gomatrixserverlib.HeaderedEvent, err error) {
|
||||
senders, notSenders := getSendersRoomEventFilter(filter)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext(
|
||||
ctx, roomID, id, filter.Limit,
|
||||
pq.StringArray(filter.Senders),
|
||||
pq.StringArray(filter.NotSenders),
|
||||
pq.StringArray(senders),
|
||||
pq.StringArray(notSenders),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
|
||||
)
|
||||
|
|
@ -494,10 +545,11 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
|
|||
func (s *outputRoomEventsStatements) SelectContextAfterEvent(
|
||||
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
|
||||
) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) {
|
||||
senders, notSenders := getSendersRoomEventFilter(filter)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext(
|
||||
ctx, roomID, id, filter.Limit,
|
||||
pq.StringArray(filter.Senders),
|
||||
pq.StringArray(filter.NotSenders),
|
||||
pq.StringArray(senders),
|
||||
pq.StringArray(notSenders),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -73,9 +73,6 @@ const selectMaxPositionInTopologySQL = "" +
|
|||
"SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" +
|
||||
") ORDER BY stream_position DESC LIMIT 1"
|
||||
|
||||
const deleteTopologyForRoomSQL = "" +
|
||||
"DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
|
||||
|
||||
const selectStreamToTopologicalPositionAscSQL = "" +
|
||||
"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;"
|
||||
|
||||
|
|
@ -88,7 +85,6 @@ type outputRoomEventsTopologyStatements struct {
|
|||
selectEventIDsInRangeDESCStmt *sql.Stmt
|
||||
selectPositionInTopologyStmt *sql.Stmt
|
||||
selectMaxPositionInTopologyStmt *sql.Stmt
|
||||
deleteTopologyForRoomStmt *sql.Stmt
|
||||
selectStreamToTopologicalPositionAscStmt *sql.Stmt
|
||||
selectStreamToTopologicalPositionDescStmt *sql.Stmt
|
||||
}
|
||||
|
|
@ -114,9 +110,6 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) {
|
|||
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -148,9 +141,9 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
|||
// is requested or not.
|
||||
var stmt *sql.Stmt
|
||||
if chronologicalOrder {
|
||||
stmt = s.selectEventIDsInRangeASCStmt
|
||||
stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt)
|
||||
} else {
|
||||
stmt = s.selectEventIDsInRangeDESCStmt
|
||||
stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt)
|
||||
}
|
||||
|
||||
// Query the event IDs.
|
||||
|
|
@ -203,10 +196,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
|
|||
err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre
|
|||
// Returns an error if there was a problem talking with the database.
|
||||
// Does not include any transaction IDs in the returned events.
|
||||
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs)
|
||||
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
|
|||
|
||||
// Check if we have all of the event's previous events. If an event is
|
||||
// missing, add it to the room's backward extremities.
|
||||
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs())
|
||||
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -429,7 +429,8 @@ func (d *Database) updateRoomState(
|
|||
func (d *Database) GetEventsInTopologicalRange(
|
||||
ctx context.Context,
|
||||
from, to *types.TopologyToken,
|
||||
roomID string, limit int,
|
||||
roomID string,
|
||||
filter *gomatrixserverlib.RoomEventFilter,
|
||||
backwardOrdering bool,
|
||||
) (events []types.StreamEvent, err error) {
|
||||
var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
|
||||
|
|
@ -450,14 +451,14 @@ func (d *Database) GetEventsInTopologicalRange(
|
|||
// Select the event IDs from the defined range.
|
||||
var eIDs []string
|
||||
eIDs, err = d.Topology.SelectEventIDsInRange(
|
||||
ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, limit, !backwardOrdering,
|
||||
ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Retrieve the events' contents using their IDs.
|
||||
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs)
|
||||
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, filter, true)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -619,7 +620,7 @@ func (d *Database) fetchMissingStateEvents(
|
|||
) ([]types.StreamEvent, error) {
|
||||
// Fetch from the events table first so we pick up the stream ID for the
|
||||
// event.
|
||||
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs)
|
||||
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -687,6 +688,9 @@ func (d *Database) GetStateDeltas(
|
|||
// user has ever interacted with — joined to, kicked/banned from, left.
|
||||
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil, nil
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
|
|
@ -704,17 +708,23 @@ func (d *Database) GetStateDeltas(
|
|||
// get all the state events ever (i.e. for all available rooms) between these two positions
|
||||
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil, nil
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil, nil
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// find out which rooms this user is peeking, if any.
|
||||
// We do this before joins so any peeks get overwritten
|
||||
peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
|
||||
if err != nil {
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
|
|
@ -725,6 +735,9 @@ func (d *Database) GetStateDeltas(
|
|||
var s []types.StreamEvent
|
||||
s, err = d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
continue
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
state[peek.RoomID] = s
|
||||
|
|
@ -752,6 +765,9 @@ func (d *Database) GetStateDeltas(
|
|||
var s []types.StreamEvent
|
||||
s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
continue
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
state[roomID] = s
|
||||
|
|
@ -802,6 +818,9 @@ func (d *Database) GetStateDeltasForFullStateSync(
|
|||
// user has ever interacted with — joined to, kicked/banned from, left.
|
||||
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil, nil
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
|
|
@ -818,7 +837,7 @@ func (d *Database) GetStateDeltasForFullStateSync(
|
|||
deltas := make(map[string]types.StateDelta)
|
||||
|
||||
peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
|
||||
if err != nil {
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
|
|
@ -827,6 +846,9 @@ func (d *Database) GetStateDeltasForFullStateSync(
|
|||
if !peek.Deleted {
|
||||
s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
|
||||
if stateErr != nil {
|
||||
if stateErr == sql.ErrNoRows {
|
||||
continue
|
||||
}
|
||||
return nil, nil, stateErr
|
||||
}
|
||||
deltas[peek.RoomID] = types.StateDelta{
|
||||
|
|
@ -840,10 +862,16 @@ func (d *Database) GetStateDeltasForFullStateSync(
|
|||
// Get all the state events ever between these two positions
|
||||
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil, nil
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil, nil
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
|
|
@ -868,6 +896,9 @@ func (d *Database) GetStateDeltasForFullStateSync(
|
|||
for _, joinedRoomID := range joinedRoomIDs {
|
||||
s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter)
|
||||
if stateErr != nil {
|
||||
if stateErr == sql.ErrNoRows {
|
||||
continue
|
||||
}
|
||||
return nil, nil, stateErr
|
||||
}
|
||||
deltas[joinedRoomID] = types.StateDelta{
|
||||
|
|
|
|||
|
|
@ -41,23 +41,23 @@ const insertAccountDataSQL = "" +
|
|||
" ON CONFLICT (user_id, room_id, type) DO UPDATE" +
|
||||
" SET id = $5"
|
||||
|
||||
// further parameters are added by prepareWithFilters
|
||||
const selectAccountDataInRangeSQL = "" +
|
||||
"SELECT room_id, type FROM syncapi_account_data_type" +
|
||||
" WHERE user_id = $1 AND id > $2 AND id <= $3" +
|
||||
" ORDER BY id ASC"
|
||||
" WHERE user_id = $1 AND id > $2 AND id <= $3"
|
||||
|
||||
const selectMaxAccountDataIDSQL = "" +
|
||||
"SELECT MAX(id) FROM syncapi_account_data_type"
|
||||
|
||||
type accountDataStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
insertAccountDataStmt *sql.Stmt
|
||||
selectMaxAccountDataIDStmt *sql.Stmt
|
||||
selectAccountDataInRangeStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
|
||||
func NewSqliteAccountDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.AccountData, error) {
|
||||
s := &accountDataStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
|
|
@ -94,18 +94,24 @@ func (s *accountDataStatements) SelectAccountDataInRange(
|
|||
ctx context.Context,
|
||||
userID string,
|
||||
r types.Range,
|
||||
accountDataFilterPart *gomatrixserverlib.EventFilter,
|
||||
filter *gomatrixserverlib.EventFilter,
|
||||
) (data map[string][]string, err error) {
|
||||
data = make(map[string][]string)
|
||||
stmt, params, err := prepareWithFilters(
|
||||
s.db, nil, selectAccountDataInRangeSQL,
|
||||
[]interface{}{
|
||||
userID, r.Low(), r.High(),
|
||||
},
|
||||
filter.Senders, filter.NotSenders,
|
||||
filter.Types, filter.NotTypes,
|
||||
[]string{}, nil, filter.Limit, FilterOrderAsc)
|
||||
|
||||
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High())
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
|
||||
|
||||
var entries int
|
||||
|
||||
for rows.Next() {
|
||||
var dataType string
|
||||
var roomID string
|
||||
|
|
@ -114,31 +120,11 @@ func (s *accountDataStatements) SelectAccountDataInRange(
|
|||
return
|
||||
}
|
||||
|
||||
// check if we should add this by looking at the filter.
|
||||
// It would be nice if we could do this in SQL-land, but the mix of variadic
|
||||
// and positional parameters makes the query annoyingly hard to do, it's easier
|
||||
// and clearer to do it in Go-land. If there are no filters for [not]types then
|
||||
// this gets skipped.
|
||||
for _, includeType := range accountDataFilterPart.Types {
|
||||
if includeType != dataType { // TODO: wildcard support
|
||||
continue
|
||||
}
|
||||
}
|
||||
for _, excludeType := range accountDataFilterPart.NotTypes {
|
||||
if excludeType == dataType { // TODO: wildcard support
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if len(data[roomID]) > 0 {
|
||||
data[roomID] = append(data[roomID], dataType)
|
||||
} else {
|
||||
data[roomID] = []string{dataType}
|
||||
}
|
||||
entries++
|
||||
if entries >= accountDataFilterPart.Limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
|
|
|
|||
|
|
@ -47,15 +47,11 @@ const selectBackwardExtremitiesForRoomSQL = "" +
|
|||
const deleteBackwardExtremitySQL = "" +
|
||||
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
|
||||
|
||||
const deleteBackwardExtremitiesForRoomSQL = "" +
|
||||
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
|
||||
|
||||
type backwardExtremitiesStatements struct {
|
||||
db *sql.DB
|
||||
insertBackwardExtremityStmt *sql.Stmt
|
||||
selectBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||
deleteBackwardExtremityStmt *sql.Stmt
|
||||
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
|
||||
|
|
@ -75,9 +71,6 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
|
|||
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
|
@ -116,10 +109,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
|||
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ const selectEventsWithEventIDsSQL = "" +
|
|||
|
||||
type currentRoomStateStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
upsertRoomStateStmt *sql.Stmt
|
||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||
deleteRoomStateForRoomStmt *sql.Stmt
|
||||
|
|
@ -100,7 +100,7 @@ type currentRoomStateStatements struct {
|
|||
selectStateEventStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
|
||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
|
||||
s := ¤tRoomStateStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
|
|
@ -220,7 +220,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
|
|||
},
|
||||
stateFilter.Senders, stateFilter.NotSenders,
|
||||
stateFilter.Types, stateFilter.NotTypes,
|
||||
excludeEventIDs, stateFilter.Limit, FilterOrderNone,
|
||||
excludeEventIDs, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderNone,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
|
|
|
|||
|
|
@ -25,33 +25,52 @@ const (
|
|||
// parts.
|
||||
func prepareWithFilters(
|
||||
db *sql.DB, txn *sql.Tx, query string, params []interface{},
|
||||
senders, notsenders, types, nottypes []string, excludeEventIDs []string,
|
||||
limit int, order FilterOrder,
|
||||
senders, notsenders, types, nottypes *[]string, excludeEventIDs []string,
|
||||
containsURL *bool, limit int, order FilterOrder,
|
||||
) (*sql.Stmt, []interface{}, error) {
|
||||
offset := len(params)
|
||||
if count := len(senders); count > 0 {
|
||||
if senders != nil {
|
||||
if count := len(*senders); count > 0 {
|
||||
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range senders {
|
||||
for _, v := range *senders {
|
||||
params, offset = append(params, v), offset+1
|
||||
}
|
||||
} else {
|
||||
query += ` AND sender = ""`
|
||||
}
|
||||
if count := len(notsenders); count > 0 {
|
||||
}
|
||||
if notsenders != nil {
|
||||
if count := len(*notsenders); count > 0 {
|
||||
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range notsenders {
|
||||
for _, v := range *notsenders {
|
||||
params, offset = append(params, v), offset+1
|
||||
}
|
||||
} else {
|
||||
query += ` AND sender NOT = ""`
|
||||
}
|
||||
if count := len(types); count > 0 {
|
||||
}
|
||||
if types != nil {
|
||||
if count := len(*types); count > 0 {
|
||||
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range types {
|
||||
for _, v := range *types {
|
||||
params, offset = append(params, v), offset+1
|
||||
}
|
||||
} else {
|
||||
query += ` AND type = ""`
|
||||
}
|
||||
if count := len(nottypes); count > 0 {
|
||||
}
|
||||
if nottypes != nil {
|
||||
if count := len(*nottypes); count > 0 {
|
||||
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range nottypes {
|
||||
for _, v := range *nottypes {
|
||||
params, offset = append(params, v), offset+1
|
||||
}
|
||||
} else {
|
||||
query += ` AND type NOT = ""`
|
||||
}
|
||||
}
|
||||
if containsURL != nil {
|
||||
query += fmt.Sprintf(" AND contains_url = %v", *containsURL)
|
||||
}
|
||||
if count := len(excludeEventIDs); count > 0 {
|
||||
query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
|
|
|
|||
|
|
@ -59,14 +59,14 @@ const selectMaxInviteIDSQL = "" +
|
|||
|
||||
type inviteEventsStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
insertInviteEventStmt *sql.Stmt
|
||||
selectInviteEventsInRangeStmt *sql.Stmt
|
||||
deleteInviteEventStmt *sql.Stmt
|
||||
selectMaxInviteIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
|
||||
func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) {
|
||||
s := &inviteEventsStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
|
|
@ -57,12 +56,6 @@ const upsertMembershipSQL = "" +
|
|||
" ON CONFLICT (room_id, user_id, membership)" +
|
||||
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
|
||||
|
||||
const selectMembershipSQL = "" +
|
||||
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
|
||||
" WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" +
|
||||
" ORDER BY stream_pos DESC" +
|
||||
" LIMIT 1"
|
||||
|
||||
const selectMembershipCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM (" +
|
||||
" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +
|
||||
|
|
@ -111,22 +104,6 @@ func (s *membershipsStatements) UpsertMembership(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *membershipsStatements) SelectMembership(
|
||||
ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
|
||||
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
|
||||
params := []interface{}{roomID, userID}
|
||||
for _, membership := range memberships {
|
||||
params = append(params, membership)
|
||||
}
|
||||
orig := strings.Replace(selectMembershipSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
|
||||
stmt, err := s.db.Prepare(orig)
|
||||
if err != nil {
|
||||
return "", 0, 0, err
|
||||
}
|
||||
err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipsStatements) SelectMembershipCount(
|
||||
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
|
||||
) (count int, err error) {
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ const insertEventSQL = "" +
|
|||
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
|
||||
|
||||
const selectEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)"
|
||||
|
||||
const selectRecentEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||
|
|
@ -111,9 +111,8 @@ const selectContextAfterEventSQL = "" +
|
|||
|
||||
type outputRoomEventsStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
insertEventStmt *sql.Stmt
|
||||
selectEventsStmt *sql.Stmt
|
||||
selectMaxEventIDStmt *sql.Stmt
|
||||
updateEventJSONStmt *sql.Stmt
|
||||
deleteEventsForRoomStmt *sql.Stmt
|
||||
|
|
@ -122,7 +121,7 @@ type outputRoomEventsStatements struct {
|
|||
selectContextAfterEventStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
|
||||
func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) {
|
||||
s := &outputRoomEventsStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
|
|
@ -133,7 +132,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
|
|||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertEventStmt, insertEventSQL},
|
||||
{&s.selectEventsStmt, selectEventsSQL},
|
||||
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
|
||||
{&s.updateEventJSONStmt, updateEventJSONSQL},
|
||||
{&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL},
|
||||
|
|
@ -170,7 +168,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
|
|||
s.db, txn, stmtSQL, inputParams,
|
||||
stateFilter.Senders, stateFilter.NotSenders,
|
||||
stateFilter.Types, stateFilter.NotTypes,
|
||||
nil, stateFilter.Limit, FilterOrderAsc,
|
||||
nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
|
|
@ -279,7 +277,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
|
|||
// Parse content as JSON and search for an "url" key
|
||||
containsURL := false
|
||||
var content map[string]interface{}
|
||||
if json.Unmarshal(event.Content(), &content) != nil {
|
||||
if json.Unmarshal(event.Content(), &content) == nil {
|
||||
// Set containsURL to true if url is present
|
||||
_, containsURL = content["url"]
|
||||
}
|
||||
|
|
@ -347,7 +345,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
|
|||
},
|
||||
eventFilter.Senders, eventFilter.NotSenders,
|
||||
eventFilter.Types, eventFilter.NotTypes,
|
||||
nil, eventFilter.Limit+1, FilterOrderDesc,
|
||||
nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
|
|
@ -395,7 +393,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
|||
},
|
||||
eventFilter.Senders, eventFilter.NotSenders,
|
||||
eventFilter.Types, eventFilter.NotTypes,
|
||||
nil, eventFilter.Limit, FilterOrderAsc,
|
||||
nil, eventFilter.ContainsURL, eventFilter.Limit, FilterOrderAsc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
|
|
@ -421,21 +419,50 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
|||
// selectEvents returns the events for the given event IDs. If an event is
|
||||
// missing from the database, it will be omitted.
|
||||
func (s *outputRoomEventsStatements) SelectEvents(
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
|
||||
) ([]types.StreamEvent, error) {
|
||||
var returnEvents []types.StreamEvent
|
||||
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
|
||||
for _, eventID := range eventIDs {
|
||||
rows, err := stmt.QueryContext(ctx, eventID)
|
||||
iEventIDs := make([]interface{}, len(eventIDs))
|
||||
for i := range eventIDs {
|
||||
iEventIDs[i] = eventIDs[i]
|
||||
}
|
||||
selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
|
||||
|
||||
if filter == nil {
|
||||
filter = &gomatrixserverlib.RoomEventFilter{Limit: 20}
|
||||
}
|
||||
stmt, params, err := prepareWithFilters(
|
||||
s.db, txn, selectSQL, iEventIDs,
|
||||
filter.Senders, filter.NotSenders,
|
||||
filter.Types, filter.NotTypes,
|
||||
nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if streamEvents, err := rowsToStreamEvents(rows); err == nil {
|
||||
returnEvents = append(returnEvents, streamEvents...)
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
|
||||
streamEvents, err := rowsToStreamEvents(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if preserveOrder {
|
||||
var returnEvents []types.StreamEvent
|
||||
eventMap := make(map[string]types.StreamEvent)
|
||||
for _, ev := range streamEvents {
|
||||
eventMap[ev.EventID()] = ev
|
||||
}
|
||||
for _, eventID := range eventIDs {
|
||||
ev, ok := eventMap[eventID]
|
||||
if ok {
|
||||
returnEvents = append(returnEvents, ev)
|
||||
}
|
||||
internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
|
||||
}
|
||||
return returnEvents, nil
|
||||
}
|
||||
return streamEvents, nil
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
|
||||
|
|
@ -507,7 +534,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
|
|||
},
|
||||
filter.Senders, filter.NotSenders,
|
||||
filter.Types, filter.NotTypes,
|
||||
nil, filter.Limit, FilterOrderDesc,
|
||||
nil, filter.ContainsURL, filter.Limit, FilterOrderDesc,
|
||||
)
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
|
|
@ -543,7 +570,7 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent(
|
|||
},
|
||||
filter.Senders, filter.NotSenders,
|
||||
filter.Types, filter.NotTypes,
|
||||
nil, filter.Limit, FilterOrderAsc,
|
||||
nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
|
||||
)
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
|
|
|
|||
|
|
@ -78,7 +78,6 @@ type outputRoomEventsTopologyStatements struct {
|
|||
selectEventIDsInRangeDESCStmt *sql.Stmt
|
||||
selectPositionInTopologyStmt *sql.Stmt
|
||||
selectMaxPositionInTopologyStmt *sql.Stmt
|
||||
deleteTopologyForRoomStmt *sql.Stmt
|
||||
selectStreamToTopologicalPositionAscStmt *sql.Stmt
|
||||
selectStreamToTopologicalPositionDescStmt *sql.Stmt
|
||||
}
|
||||
|
|
@ -191,10 +190,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
|
|||
err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ const selectMaxPeekIDSQL = "" +
|
|||
|
||||
type peekStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
insertPeekStmt *sql.Stmt
|
||||
deletePeekStmt *sql.Stmt
|
||||
deletePeeksStmt *sql.Stmt
|
||||
|
|
@ -75,7 +75,7 @@ type peekStatements struct {
|
|||
selectMaxPeekIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) {
|
||||
func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) {
|
||||
_, err := db.Exec(peeksSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ const selectPresenceAfter = "" +
|
|||
|
||||
type presenceStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
upsertPresenceStmt *sql.Stmt
|
||||
upsertPresenceFromSyncStmt *sql.Stmt
|
||||
selectPresenceForUsersStmt *sql.Stmt
|
||||
|
|
@ -83,7 +83,7 @@ type presenceStatements struct {
|
|||
selectPresenceAfterStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqlitePresenceTable(db *sql.DB, streamID *streamIDStatements) (*presenceStatements, error) {
|
||||
func NewSqlitePresenceTable(db *sql.DB, streamID *StreamIDStatements) (*presenceStatements, error) {
|
||||
_, err := db.Exec(presenceSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -59,13 +59,13 @@ const selectMaxReceiptIDSQL = "" +
|
|||
|
||||
type receiptStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
upsertReceipt *sql.Stmt
|
||||
selectRoomReceipts *sql.Stmt
|
||||
selectMaxReceiptID *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) {
|
||||
func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) {
|
||||
_, err := db.Exec(receiptsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -32,12 +32,12 @@ const increaseStreamIDStmt = "" +
|
|||
"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" +
|
||||
" RETURNING stream_id"
|
||||
|
||||
type streamIDStatements struct {
|
||||
type StreamIDStatements struct {
|
||||
db *sql.DB
|
||||
increaseStreamIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
||||
func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(streamIDTableSchema)
|
||||
if err != nil {
|
||||
|
|
@ -49,31 +49,31 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
func (s *StreamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
func (s *StreamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
func (s *StreamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ type SyncServerDatasource struct {
|
|||
shared.Database
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
streamID streamIDStatements
|
||||
streamID StreamIDStatements
|
||||
}
|
||||
|
||||
// NewDatabase creates a new sync server database
|
||||
|
|
@ -49,7 +49,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
|
|||
}
|
||||
|
||||
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
|
||||
if err = d.streamID.prepare(d.db); err != nil {
|
||||
if err = d.streamID.Prepare(d.db); err != nil {
|
||||
return err
|
||||
}
|
||||
accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package storage_test
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
|
|
@ -38,7 +39,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
|
|||
if err != nil {
|
||||
t.Fatalf("WriteEvent failed: %s", err)
|
||||
}
|
||||
fmt.Println("Event ID", ev.EventID(), " spos=", pos, "depth=", ev.Depth())
|
||||
t.Logf("Event ID %s spos=%v depth=%v", ev.EventID(), pos, ev.Depth())
|
||||
positions = append(positions, pos)
|
||||
}
|
||||
return
|
||||
|
|
@ -46,7 +47,6 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
|
|||
|
||||
func TestWriteEvents(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
t.Parallel()
|
||||
alice := test.NewUser()
|
||||
r := test.NewRoom(t, alice)
|
||||
db, close := MustCreateDatabase(t, dbType)
|
||||
|
|
@ -61,12 +61,18 @@ func TestRecentEventsPDU(t *testing.T) {
|
|||
db, close := MustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser()
|
||||
var filter gomatrixserverlib.RoomEventFilter
|
||||
filter.Limit = 100
|
||||
// dummy room to make sure SQL queries are filtering on room ID
|
||||
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
|
||||
|
||||
// actual test room
|
||||
r := test.NewRoom(t, alice)
|
||||
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"})
|
||||
events := r.Events()
|
||||
positions := MustWriteEvents(t, db, events)
|
||||
|
||||
// dummy room to make sure SQL queries are filtering on room ID
|
||||
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
|
||||
|
||||
latest, err := db.MaxStreamPositionForPDUs(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
|
||||
|
|
@ -76,69 +82,63 @@ func TestRecentEventsPDU(t *testing.T) {
|
|||
Name string
|
||||
From types.StreamPosition
|
||||
To types.StreamPosition
|
||||
Limit int
|
||||
ReverseOrder bool
|
||||
WantEvents []*gomatrixserverlib.HeaderedEvent
|
||||
WantLimited bool
|
||||
}{
|
||||
// The purpose of this test is to make sure that incremental syncs are including up to the latest events.
|
||||
// It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event.
|
||||
// It's a basic sanity test that sync works. It creates a streaming position that is on the penultimate event.
|
||||
// It makes sure the response includes the final event.
|
||||
{
|
||||
Name: "IncrementalSync penultimate",
|
||||
Name: "penultimate",
|
||||
From: positions[len(positions)-2], // pretend we are at the penultimate event
|
||||
To: latest,
|
||||
Limit: 100,
|
||||
WantEvents: events[len(events)-1:],
|
||||
WantLimited: false,
|
||||
},
|
||||
/*
|
||||
// The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the
|
||||
// number of returned events. This is critical for big rooms hence the test here.
|
||||
// The purpose of this test is to check that limits can be applied and work.
|
||||
// This is critical for big rooms hence the test here.
|
||||
{
|
||||
Name: "IncrementalSync limited",
|
||||
DoSync: func() (*types.Response, error) {
|
||||
from := types.StreamingToken{ // pretend we are 10 events behind
|
||||
PDUPosition: positions[len(positions)-11],
|
||||
}
|
||||
res := types.NewResponse()
|
||||
// limit is set to 5
|
||||
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
|
||||
Name: "limited",
|
||||
From: 0,
|
||||
To: latest,
|
||||
Limit: 1,
|
||||
WantEvents: events[len(events)-1:],
|
||||
WantLimited: true,
|
||||
},
|
||||
// want the last 5 events, NOT the last 10.
|
||||
WantTimeline: events[len(events)-5:],
|
||||
},
|
||||
// The purpose of this test is to check that CompleteSync returns all the current state as well as
|
||||
// honouring the `numRecentEventsPerRoom` value
|
||||
// The purpose of this test is to check that we can return every event with a high
|
||||
// enough limit
|
||||
{
|
||||
Name: "CompleteSync limited",
|
||||
DoSync: func() (*types.Response, error) {
|
||||
res := types.NewResponse()
|
||||
// limit set to 5
|
||||
return db.CompleteSync(ctx, res, testUserDeviceA, 5)
|
||||
Name: "large limited",
|
||||
From: 0,
|
||||
To: latest,
|
||||
Limit: 100,
|
||||
WantEvents: events,
|
||||
WantLimited: false,
|
||||
},
|
||||
// want the last 5 events
|
||||
WantTimeline: events[len(events)-5:],
|
||||
// want all state for the room
|
||||
WantState: state,
|
||||
},
|
||||
// The purpose of this test is to check that CompleteSync can return everything with a high enough
|
||||
// `numRecentEventsPerRoom`.
|
||||
// The purpose of this test is to check that we can return events in reverse order
|
||||
{
|
||||
Name: "CompleteSync",
|
||||
DoSync: func() (*types.Response, error) {
|
||||
res := types.NewResponse()
|
||||
return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
|
||||
Name: "reverse",
|
||||
From: positions[len(positions)-3], // 2 events back
|
||||
To: latest,
|
||||
Limit: 100,
|
||||
ReverseOrder: true,
|
||||
WantEvents: test.Reversed(events[len(events)-2:]),
|
||||
WantLimited: false,
|
||||
},
|
||||
WantTimeline: events,
|
||||
// We want no state at all as that field in /sync is the delta between the token (beginning of time)
|
||||
// and the START of the timeline.
|
||||
}, */
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
for i := range testCases {
|
||||
tc := testCases[i]
|
||||
t.Run(tc.Name, func(st *testing.T) {
|
||||
var filter gomatrixserverlib.RoomEventFilter
|
||||
filter.Limit = tc.Limit
|
||||
gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
|
||||
From: tc.From,
|
||||
To: tc.To,
|
||||
}, &filter, true, true)
|
||||
}, &filter, !tc.ReverseOrder, true)
|
||||
if err != nil {
|
||||
st.Fatalf("failed to do sync: %s", err)
|
||||
}
|
||||
|
|
@ -148,100 +148,49 @@ func TestRecentEventsPDU(t *testing.T) {
|
|||
if len(gotEvents) != len(tc.WantEvents) {
|
||||
st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents))
|
||||
}
|
||||
for j := range gotEvents {
|
||||
if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) {
|
||||
st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON()))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := MustCreateDatabase(t)
|
||||
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
||||
positions := MustWriteEvents(t, db, events)
|
||||
latest, err := db.SyncPosition(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get SyncPosition: %s", err)
|
||||
}
|
||||
from := types.StreamingToken{
|
||||
PDUPosition: positions[len(positions)-2],
|
||||
}
|
||||
|
||||
res := types.NewResponse()
|
||||
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to IncrementalSync with latest token")
|
||||
}
|
||||
roomRes, ok := res.Rooms.Join[testRoomID]
|
||||
if !ok {
|
||||
t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res)
|
||||
}
|
||||
// returns the last event "Message 10"
|
||||
assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:]))
|
||||
|
||||
prev := roomRes.Timeline.PrevBatch.String()
|
||||
if prev == "" {
|
||||
t.Fatalf("IncrementalSync expected prev_batch token")
|
||||
}
|
||||
prevBatchToken, err := types.NewTopologyTokenFromString(prev)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to NewTopologyTokenFromString : %s", err)
|
||||
}
|
||||
// backpaginate 5 messages starting at the latest position.
|
||||
// head towards the beginning of time
|
||||
to := types.TopologyToken{}
|
||||
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true)
|
||||
if err != nil {
|
||||
t.Fatalf("GetEventsInRange returned an error: %s", err)
|
||||
}
|
||||
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
|
||||
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1]))
|
||||
}
|
||||
|
||||
// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token.
|
||||
func TestGetEventsInRangeWithStreamToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := MustCreateDatabase(t)
|
||||
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
||||
MustWriteEvents(t, db, events)
|
||||
latest, err := db.SyncPosition(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get SyncPosition: %s", err)
|
||||
}
|
||||
// head towards the beginning of time
|
||||
to := types.StreamingToken{}
|
||||
|
||||
// backpaginate 5 messages starting at the latest position.
|
||||
paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
|
||||
if err != nil {
|
||||
t.Fatalf("GetEventsInRange returned an error: %s", err)
|
||||
}
|
||||
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
|
||||
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
|
||||
}
|
||||
|
||||
// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token
|
||||
func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := MustCreateDatabase(t)
|
||||
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
||||
MustWriteEvents(t, db, events)
|
||||
from, err := db.MaxTopologicalPosition(ctx, testRoomID)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := MustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser()
|
||||
r := test.NewRoom(t, alice)
|
||||
for i := 0; i < 10; i++ {
|
||||
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})
|
||||
}
|
||||
events := r.Events()
|
||||
_ = MustWriteEvents(t, db, events)
|
||||
|
||||
from, err := db.MaxTopologicalPosition(ctx, r.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
|
||||
}
|
||||
t.Logf("max topo pos = %+v", from)
|
||||
// head towards the beginning of time
|
||||
to := types.TopologyToken{}
|
||||
|
||||
// backpaginate 5 messages starting at the latest position.
|
||||
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true)
|
||||
filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
|
||||
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
|
||||
if err != nil {
|
||||
t.Fatalf("GetEventsInRange returned an error: %s", err)
|
||||
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
|
||||
}
|
||||
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
|
||||
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
|
||||
gots := db.StreamEventsToEvents(nil, paginatedEvents)
|
||||
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth.
|
||||
// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent
|
||||
// will appear FIRST when going backwards. This test creates a DAG like:
|
||||
|
|
@ -651,12 +600,4 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ
|
|||
tok.Decrement()
|
||||
return &tok
|
||||
}
|
||||
|
||||
func reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
|
||||
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
|
||||
for i := 0; i < len(in); i++ {
|
||||
out[i] = in[len(in)-i-1]
|
||||
}
|
||||
return out
|
||||
}
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ type Events interface {
|
|||
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
|
||||
// SelectEarlyEvents returns the earliest events in the given room.
|
||||
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
|
||||
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
|
||||
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error)
|
||||
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
|
||||
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
|
||||
DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
|
||||
|
|
@ -84,8 +84,6 @@ type Topology interface {
|
|||
SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error)
|
||||
// SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position.
|
||||
SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error)
|
||||
// DeleteTopologyForRoom removes all topological information for a room. This should only be done when removing the room entirely.
|
||||
DeleteTopologyForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
|
||||
// SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room.
|
||||
SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error)
|
||||
}
|
||||
|
|
@ -132,8 +130,6 @@ type BackwardsExtremities interface {
|
|||
SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error)
|
||||
// DeleteBackwardExtremity removes a backwards extremity for a room, if one existed.
|
||||
DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error)
|
||||
// DeleteBackwardExtremitiesFoorRoomID removes all backward extremities for a room. This should only be done when removing the room entirely.
|
||||
DeleteBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
|
||||
}
|
||||
|
||||
// SendToDevice tracks send-to-device messages which are sent to individual
|
||||
|
|
@ -173,7 +169,6 @@ type Receipts interface {
|
|||
|
||||
type Memberships interface {
|
||||
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
|
||||
SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error)
|
||||
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
|
||||
}
|
||||
|
||||
|
|
|
|||
105
syncapi/storage/tables/output_room_events_test.go
Normal file
105
syncapi/storage/tables/output_room_events_test.go
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open db: %s", err)
|
||||
}
|
||||
|
||||
var tab tables.Events
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
tab, err = postgres.NewPostgresEventsTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
var stream sqlite3.StreamIDStatements
|
||||
if err = stream.Prepare(db); err != nil {
|
||||
t.Fatalf("failed to prepare stream stmts: %s", err)
|
||||
}
|
||||
tab, err = sqlite3.NewSqliteEventsTable(db, &stream)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make new table: %s", err)
|
||||
}
|
||||
return tab, db, close
|
||||
}
|
||||
|
||||
func TestOutputRoomEventsTable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
alice := test.NewUser()
|
||||
room := test.NewRoom(t, alice)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, db, close := newOutputRoomEventsTable(t, dbType)
|
||||
defer close()
|
||||
events := room.Events()
|
||||
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
|
||||
for _, ev := range events {
|
||||
_, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to InsertEvent: %s", err)
|
||||
}
|
||||
}
|
||||
// order = 2,0,3,1
|
||||
wantEventIDs := []string{
|
||||
events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(),
|
||||
}
|
||||
gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEvents: %s", err)
|
||||
}
|
||||
gotEventIDs := make([]string, len(gotEvents))
|
||||
for i := range gotEvents {
|
||||
gotEventIDs[i] = gotEvents[i].EventID()
|
||||
}
|
||||
if !reflect.DeepEqual(gotEventIDs, wantEventIDs) {
|
||||
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs)
|
||||
}
|
||||
|
||||
// Test that contains_url is correctly populated
|
||||
urlEv := room.CreateEvent(t, alice, "m.text", map[string]interface{}{
|
||||
"body": "test.txt",
|
||||
"url": "mxc://test.txt",
|
||||
})
|
||||
if _, err = tab.InsertEvent(ctx, txn, urlEv, nil, nil, nil, false); err != nil {
|
||||
return fmt.Errorf("failed to InsertEvent: %s", err)
|
||||
}
|
||||
wantEventID := []string{urlEv.EventID()}
|
||||
t := true
|
||||
gotEvents, err = tab.SelectEvents(ctx, txn, wantEventID, &gomatrixserverlib.RoomEventFilter{Limit: 1, ContainsURL: &t}, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEvents: %s", err)
|
||||
}
|
||||
gotEventIDs = make([]string, len(gotEvents))
|
||||
for i := range gotEvents {
|
||||
gotEventIDs[i] = gotEvents[i].EventID()
|
||||
}
|
||||
if !reflect.DeepEqual(gotEventIDs, wantEventID) {
|
||||
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventID)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
91
syncapi/storage/tables/topology_test.go
Normal file
91
syncapi/storage/tables/topology_test.go
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
)
|
||||
|
||||
func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open db: %s", err)
|
||||
}
|
||||
|
||||
var tab tables.Topology
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
tab, err = postgres.NewPostgresTopologyTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
tab, err = sqlite3.NewSqliteTopologyTable(db)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make new table: %s", err)
|
||||
}
|
||||
return tab, db, close
|
||||
}
|
||||
|
||||
func TestTopologyTable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
alice := test.NewUser()
|
||||
room := test.NewRoom(t, alice)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, db, close := newTopologyTable(t, dbType)
|
||||
defer close()
|
||||
events := room.Events()
|
||||
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
|
||||
var highestPos types.StreamPosition
|
||||
for i, ev := range events {
|
||||
topoPos, err := tab.InsertEventInTopology(ctx, txn, ev, types.StreamPosition(i))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to InsertEventInTopology: %s", err)
|
||||
}
|
||||
// topo pos = depth, depth starts at 1, hence 1+i
|
||||
if topoPos != types.StreamPosition(1+i) {
|
||||
return fmt.Errorf("got topo pos %d want %d", topoPos, 1+i)
|
||||
}
|
||||
highestPos = topoPos + 1
|
||||
}
|
||||
// check ordering works without limit
|
||||
eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||
}
|
||||
test.AssertEventIDsEqual(t, eventIDs, events[:])
|
||||
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||
}
|
||||
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:]))
|
||||
// check ordering works with limit
|
||||
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||
}
|
||||
test.AssertEventIDsEqual(t, eventIDs, events[:3])
|
||||
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||
}
|
||||
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:]))
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -3,9 +3,11 @@ package streams
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -26,7 +28,8 @@ type PDUStreamProvider struct {
|
|||
|
||||
tasks chan func()
|
||||
workers atomic.Int32
|
||||
userAPI userapi.UserInternalAPI
|
||||
// userID+deviceID -> lazy loading cache
|
||||
lazyLoadCache *caching.LazyLoadCache
|
||||
}
|
||||
|
||||
func (p *PDUStreamProvider) worker() {
|
||||
|
|
@ -188,7 +191,7 @@ func (p *PDUStreamProvider) IncrementalSync(
|
|||
newPos = from
|
||||
for _, delta := range stateDeltas {
|
||||
var pos types.StreamPosition
|
||||
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, req.Response); err != nil {
|
||||
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, &stateFilter, req.Response); err != nil {
|
||||
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
|
||||
return to
|
||||
}
|
||||
|
|
@ -203,12 +206,14 @@ func (p *PDUStreamProvider) IncrementalSync(
|
|||
return newPos
|
||||
}
|
||||
|
||||
// nolint:gocyclo
|
||||
func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||
ctx context.Context,
|
||||
device *userapi.Device,
|
||||
r types.Range,
|
||||
delta types.StateDelta,
|
||||
eventFilter *gomatrixserverlib.RoomEventFilter,
|
||||
stateFilter *gomatrixserverlib.StateFilter,
|
||||
res *types.Response,
|
||||
) (types.StreamPosition, error) {
|
||||
if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave {
|
||||
|
|
@ -225,13 +230,16 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
|||
eventFilter, true, true,
|
||||
)
|
||||
if err != nil {
|
||||
return r.From, err
|
||||
if err == sql.ErrNoRows {
|
||||
return r.To, nil
|
||||
}
|
||||
return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err)
|
||||
}
|
||||
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
|
||||
delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back
|
||||
prevBatch, err := p.DB.GetBackwardTopologyPos(ctx, recentStreamEvents)
|
||||
if err != nil {
|
||||
return r.From, err
|
||||
return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err)
|
||||
}
|
||||
|
||||
// If we didn't return any events at all then don't bother doing anything else.
|
||||
|
|
@ -247,7 +255,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
|||
// room that were returned.
|
||||
latestPosition := r.To
|
||||
updateLatestPosition := func(mostRecentEventID string) {
|
||||
if _, pos, err := p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil {
|
||||
var pos types.StreamPosition
|
||||
if _, pos, err = p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil {
|
||||
switch {
|
||||
case r.Backwards && pos > latestPosition:
|
||||
fallthrough
|
||||
|
|
@ -263,6 +272,16 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
|||
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
|
||||
}
|
||||
|
||||
if stateFilter.LazyLoadMembers {
|
||||
delta.StateEvents, err = p.lazyLoadMembers(
|
||||
ctx, delta.RoomID, true, limited, stateFilter.IncludeRedundantMembers,
|
||||
device, recentEvents, delta.StateEvents,
|
||||
)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return r.From, fmt.Errorf("p.lazyLoadMembers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
hasMembershipChange := false
|
||||
for _, recentEvent := range recentStreamEvents {
|
||||
if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil {
|
||||
|
|
@ -322,12 +341,16 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
|||
wantFullState bool,
|
||||
device *userapi.Device,
|
||||
) (jr *types.JoinResponse, err error) {
|
||||
jr = types.NewJoinResponse()
|
||||
// TODO: When filters are added, we may need to call this multiple times to get enough events.
|
||||
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
|
||||
recentStreamEvents, limited, err := p.DB.RecentEvents(
|
||||
ctx, roomID, r, eventFilter, true, true,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return jr, nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -402,7 +425,20 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
|||
// "Can sync a room with a message with a transaction id" - which does a complete sync to check.
|
||||
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
|
||||
stateEvents = removeDuplicates(stateEvents, recentEvents)
|
||||
jr = types.NewJoinResponse()
|
||||
|
||||
if stateFilter.LazyLoadMembers {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stateEvents, err = p.lazyLoadMembers(ctx, roomID,
|
||||
false, limited, stateFilter.IncludeRedundantMembers,
|
||||
device, recentEvents, stateEvents,
|
||||
)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
jr.Summary.JoinedMemberCount = &joinedCount
|
||||
jr.Summary.InvitedMemberCount = &invitedCount
|
||||
jr.Timeline.PrevBatch = prevBatch
|
||||
|
|
@ -412,6 +448,69 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
|||
return jr, nil
|
||||
}
|
||||
|
||||
func (p *PDUStreamProvider) lazyLoadMembers(
|
||||
ctx context.Context, roomID string,
|
||||
incremental, limited, includeRedundant bool,
|
||||
device *userapi.Device,
|
||||
timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
|
||||
) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
if len(timelineEvents) == 0 {
|
||||
return stateEvents, nil
|
||||
}
|
||||
// Work out which memberships to include
|
||||
timelineUsers := make(map[string]struct{})
|
||||
if !incremental {
|
||||
timelineUsers[device.UserID] = struct{}{}
|
||||
}
|
||||
// Add all users the client doesn't know about yet to a list
|
||||
for _, event := range timelineEvents {
|
||||
// Membership is not yet cached, add it to the list
|
||||
if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.Sender()); !ok {
|
||||
timelineUsers[event.Sender()] = struct{}{}
|
||||
}
|
||||
}
|
||||
// Preallocate with the same amount, even if it will end up with fewer values
|
||||
newStateEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(stateEvents))
|
||||
// Remove existing membership events we don't care about, e.g. users not in the timeline.events
|
||||
for _, event := range stateEvents {
|
||||
if event.Type() == gomatrixserverlib.MRoomMember && event.StateKey() != nil {
|
||||
// If this is a gapped incremental sync, we still want this membership
|
||||
isGappedIncremental := limited && incremental
|
||||
// We want this users membership event, keep it in the list
|
||||
_, ok := timelineUsers[event.Sender()]
|
||||
wantMembership := ok || isGappedIncremental
|
||||
if wantMembership {
|
||||
newStateEvents = append(newStateEvents, event)
|
||||
if !includeRedundant {
|
||||
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, event.Sender(), event.EventID())
|
||||
}
|
||||
delete(timelineUsers, event.Sender())
|
||||
}
|
||||
} else {
|
||||
newStateEvents = append(newStateEvents, event)
|
||||
}
|
||||
}
|
||||
wantUsers := make([]string, 0, len(timelineUsers))
|
||||
for userID := range timelineUsers {
|
||||
wantUsers = append(wantUsers, userID)
|
||||
}
|
||||
// Query missing membership events
|
||||
memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &gomatrixserverlib.StateFilter{
|
||||
Limit: 100,
|
||||
Senders: &wantUsers,
|
||||
Types: &[]string{gomatrixserverlib.MRoomMember},
|
||||
})
|
||||
if err != nil {
|
||||
return stateEvents, err
|
||||
}
|
||||
// cache the membership events
|
||||
for _, membership := range memberships {
|
||||
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, membership.Sender(), membership.EventID())
|
||||
}
|
||||
stateEvents = append(newStateEvents, memberships...)
|
||||
return stateEvents, nil
|
||||
}
|
||||
|
||||
// addIgnoredUsersToFilter adds ignored users to the eventfilter and
|
||||
// the syncreq itself for further use in streams.
|
||||
func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error {
|
||||
|
|
@ -423,8 +522,12 @@ func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *ty
|
|||
return err
|
||||
}
|
||||
req.IgnoredUsers = *ignores
|
||||
userList := make([]string, 0, len(ignores.List))
|
||||
for userID := range ignores.List {
|
||||
eventFilter.NotSenders = append(eventFilter.NotSenders, userID)
|
||||
userList = append(userList, userID)
|
||||
}
|
||||
if len(userList) > 0 {
|
||||
eventFilter.NotSenders = &userList
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,12 +27,12 @@ type Streams struct {
|
|||
func NewSyncStreamProviders(
|
||||
d storage.Database, userAPI userapi.UserInternalAPI,
|
||||
rsAPI rsapi.RoomserverInternalAPI, keyAPI keyapi.KeyInternalAPI,
|
||||
eduCache *caching.EDUCache, notifier *notifier.Notifier,
|
||||
eduCache *caching.EDUCache, lazyLoadCache *caching.LazyLoadCache, notifier *notifier.Notifier,
|
||||
) *Streams {
|
||||
streams := &Streams{
|
||||
PDUStreamProvider: &PDUStreamProvider{
|
||||
StreamProvider: StreamProvider{DB: d},
|
||||
userAPI: userAPI,
|
||||
lazyLoadCache: lazyLoadCache,
|
||||
},
|
||||
TypingStreamProvider: &TypingStreamProvider{
|
||||
StreamProvider: StreamProvider{DB: d},
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
package sync
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
|
@ -60,10 +61,10 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
|
|||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||
}
|
||||
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil {
|
||||
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil && err != sql.ErrNoRows {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed")
|
||||
return nil, fmt.Errorf("syncDB.GetFilter: %w", err)
|
||||
} else {
|
||||
} else if f != nil {
|
||||
filter = *f
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -57,8 +57,12 @@ func AddPublicRoutes(
|
|||
}
|
||||
|
||||
eduCache := caching.NewTypingCache()
|
||||
lazyLoadCache, err := caching.NewLazyLoadCache()
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to create lazy loading cache")
|
||||
}
|
||||
notifier := notifier.NewNotifier()
|
||||
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, notifier)
|
||||
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, lazyLoadCache, notifier)
|
||||
notifier.SetCurrentPosition(streams.Latest(context.Background()))
|
||||
if err = notifier.Load(context.Background(), syncDB); err != nil {
|
||||
logrus.WithError(err).Panicf("failed to load notifier ")
|
||||
|
|
|
|||
|
|
@ -312,10 +312,10 @@ Inbound federation can return events
|
|||
Inbound federation can return missing events for world_readable visibility
|
||||
Inbound federation can return missing events for invite visibility
|
||||
Inbound federation can get public room list
|
||||
POST /rooms/:room_id/redact/:event_id as power user redacts message
|
||||
POST /rooms/:room_id/redact/:event_id as original message sender redacts message
|
||||
POST /rooms/:room_id/redact/:event_id as random user does not redact message
|
||||
POST /redact disallows redaction of event in different room
|
||||
PUT /rooms/:room_id/redact/:event_id/:txn_id as power user redacts message
|
||||
PUT /rooms/:room_id/redact/:event_id/:txn_id as original message sender redacts message
|
||||
PUT /rooms/:room_id/redact/:event_id/:txn_id as random user does not redact message
|
||||
PUT /redact disallows redaction of event in different room
|
||||
An event which redacts itself should be ignored
|
||||
A pair of events which redact each other should be ignored
|
||||
Redaction of a redaction redacts the redaction reason
|
||||
|
|
@ -697,3 +697,16 @@ Room state after a rejected state event is the same as before
|
|||
Ignore user in existing room
|
||||
Ignore invite in full sync
|
||||
Ignore invite in incremental sync
|
||||
A filtered timeline reaches its limit
|
||||
A change to displayname should not result in a full state sync
|
||||
Can fetch images in room
|
||||
The only membership state included in an initial sync is for all the senders in the timeline
|
||||
The only membership state included in an incremental sync is for senders in the timeline
|
||||
Old members are included in gappy incr LL sync if they start speaking
|
||||
We do send redundant membership state across incremental syncs if asked
|
||||
Rejecting invite over federation doesn't break incremental /sync
|
||||
Gapped incremental syncs include all state changes
|
||||
Old leaves are present in gapped incremental syncs
|
||||
Leaves are present in non-gapped incremental syncs
|
||||
Members from the gap are included in gappy incr LL sync
|
||||
Presence can be set from sync
|
||||
59
test/db.go
59
test/db.go
|
|
@ -15,12 +15,16 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"testing"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type DBType int
|
||||
|
|
@ -30,7 +34,7 @@ var DBTypePostgres DBType = 2
|
|||
|
||||
var Quiet = false
|
||||
|
||||
func createLocalDB(dbName string) string {
|
||||
func createLocalDB(dbName string) {
|
||||
if !Quiet {
|
||||
fmt.Println("Note: tests require a postgres install accessible to the current user")
|
||||
}
|
||||
|
|
@ -43,7 +47,29 @@ func createLocalDB(dbName string) string {
|
|||
if err != nil && !Quiet {
|
||||
fmt.Println("createLocalDB returned error:", err)
|
||||
}
|
||||
return dbName
|
||||
}
|
||||
|
||||
func createRemoteDB(t *testing.T, dbName, user, connStr string) {
|
||||
db, err := sql.Open("postgres", connStr+" dbname=postgres")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open postgres conn with connstr=%s : %s", connStr, err)
|
||||
}
|
||||
_, err = db.Exec(fmt.Sprintf(`CREATE DATABASE %s;`, dbName))
|
||||
if err != nil {
|
||||
pqErr, ok := err.(*pq.Error)
|
||||
if !ok {
|
||||
t.Fatalf("failed to CREATE DATABASE: %s", err)
|
||||
}
|
||||
// we ignore duplicate database error as we expect this
|
||||
if pqErr.Code != "42P04" {
|
||||
t.Fatalf("failed to CREATE DATABASE with code=%s msg=%s", pqErr.Code, pqErr.Message)
|
||||
}
|
||||
}
|
||||
_, err = db.Exec(fmt.Sprintf(`GRANT ALL PRIVILEGES ON DATABASE %s TO %s`, dbName, user))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to GRANT: %s", err)
|
||||
}
|
||||
_ = db.Close()
|
||||
}
|
||||
|
||||
func currentUser() string {
|
||||
|
|
@ -64,6 +90,7 @@ func currentUser() string {
|
|||
// TODO: namespace for concurrent package tests
|
||||
func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) {
|
||||
if dbType == DBTypeSQLite {
|
||||
// this will be made in the current working directory which namespaces concurrent package runs correctly
|
||||
dbname := "dendrite_test.db"
|
||||
return fmt.Sprintf("file:%s", dbname), func() {
|
||||
err := os.Remove(dbname)
|
||||
|
|
@ -79,13 +106,9 @@ func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, clo
|
|||
if user == "" {
|
||||
user = currentUser()
|
||||
}
|
||||
dbName := os.Getenv("POSTGRES_DB")
|
||||
if dbName == "" {
|
||||
dbName = createLocalDB("dendrite_test")
|
||||
}
|
||||
connStr = fmt.Sprintf(
|
||||
"user=%s dbname=%s sslmode=disable",
|
||||
user, dbName,
|
||||
"user=%s sslmode=disable",
|
||||
user,
|
||||
)
|
||||
// optional vars, used in CI
|
||||
password := os.Getenv("POSTGRES_PASSWORD")
|
||||
|
|
@ -97,6 +120,25 @@ func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, clo
|
|||
connStr += fmt.Sprintf(" host=%s", host)
|
||||
}
|
||||
|
||||
// superuser database
|
||||
postgresDB := os.Getenv("POSTGRES_DB")
|
||||
// we cannot use 'dendrite_test' here else 2x concurrently running packages will try to use the same db.
|
||||
// instead, hash the current working directory, snaffle the first 16 bytes and append that to dendrite_test
|
||||
// and use that as the unique db name. We do this because packages are per-directory hence by hashing the
|
||||
// working (test) directory we ensure we get a consistent hash and don't hash against concurrent packages.
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("cannot get working directory: %s", err)
|
||||
}
|
||||
hash := sha256.Sum256([]byte(wd))
|
||||
dbName := fmt.Sprintf("dendrite_test_%s", hex.EncodeToString(hash[:16]))
|
||||
if postgresDB == "" { // local server, use createdb
|
||||
createLocalDB(dbName)
|
||||
} else { // remote server, shell into the postgres user and CREATE DATABASE
|
||||
createRemoteDB(t, dbName, user, connStr)
|
||||
}
|
||||
connStr += fmt.Sprintf(" dbname=%s", dbName)
|
||||
|
||||
return connStr, func() {
|
||||
// Drop all tables on the database to get a fresh instance
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
|
|
@ -121,6 +163,7 @@ func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) {
|
|||
for dbName, dbType := range dbs {
|
||||
dbt := dbType
|
||||
t.Run(dbName, func(tt *testing.T) {
|
||||
tt.Parallel()
|
||||
testFn(tt, dbt)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -49,3 +51,40 @@ func WithUnsigned(unsigned interface{}) eventModifier {
|
|||
e.unsigned = unsigned
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse a list of events
|
||||
func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
|
||||
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
|
||||
for i := 0; i < len(in); i++ {
|
||||
out[i] = in[len(in)-i-1]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*gomatrixserverlib.HeaderedEvent) {
|
||||
t.Helper()
|
||||
if len(gotEventIDs) != len(wants) {
|
||||
t.Fatalf("length mismatch: got %d events, want %d", len(gotEventIDs), len(wants))
|
||||
}
|
||||
for i := range wants {
|
||||
w := wants[i].EventID()
|
||||
g := gotEventIDs[i]
|
||||
if w != g {
|
||||
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AssertEventsEqual(t *testing.T, gots, wants []*gomatrixserverlib.HeaderedEvent) {
|
||||
t.Helper()
|
||||
if len(gots) != len(wants) {
|
||||
t.Fatalf("length mismatch: got %d events, want %d", len(gots), len(wants))
|
||||
}
|
||||
for i := range wants {
|
||||
w := wants[i].JSON()
|
||||
g := gots[i].JSON()
|
||||
if !bytes.Equal(w, g) {
|
||||
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -494,7 +494,7 @@ type PerformPusherDeletionRequest struct {
|
|||
type Pusher struct {
|
||||
SessionID int64 `json:"session_id,omitempty"`
|
||||
PushKey string `json:"pushkey"`
|
||||
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
|
||||
PushKeyTS int64 `json:"pushkey_ts,omitempty"`
|
||||
Kind PusherKind `json:"kind"`
|
||||
AppID string `json:"app_id"`
|
||||
AppDisplayName string `json:"app_display_name"`
|
||||
|
|
|
|||
|
|
@ -653,7 +653,7 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
|
|||
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
|
||||
}
|
||||
if req.Pusher.PushKeyTS == 0 {
|
||||
req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now())
|
||||
req.Pusher.PushKeyTS = int64(time.Now().Unix())
|
||||
}
|
||||
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
|
@ -95,7 +94,7 @@ type pushersStatements struct {
|
|||
// Returns nil error success.
|
||||
func (s *pushersStatements) InsertPusher(
|
||||
ctx context.Context, txn *sql.Tx, session_id int64,
|
||||
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
logrus.Debugf("Created pusher %d", session_id)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
|
@ -95,7 +94,7 @@ type pushersStatements struct {
|
|||
// Returns nil error success.
|
||||
func (s *pushersStatements) InsertPusher(
|
||||
ctx context.Context, txn *sql.Tx, session_id int64,
|
||||
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
) error {
|
||||
_, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
logrus.Debugf("Created pusher %d", session_id)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type AccountDataTable interface {
|
||||
|
|
@ -96,7 +95,7 @@ type ThreePIDTable interface {
|
|||
}
|
||||
|
||||
type PusherTable interface {
|
||||
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
|
||||
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
|
||||
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
|
||||
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
|
||||
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
|
||||
|
|
|
|||
Loading…
Reference in a new issue