Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/phonehomestats

Commit d76489aced

.github/workflows/dendrite.yml (vendored, 27 lines changed)

@@ -73,6 +73,26 @@ jobs:
     timeout-minutes: 5
     name: Unit tests (Go ${{ matrix.go }})
     runs-on: ubuntu-latest
+    # Service containers to run with `container-job`
+    services:
+      # Label used to access the service container
+      postgres:
+        # Docker Hub image
+        image: postgres:13-alpine
+        # Provide the password for postgres
+        env:
+          POSTGRES_USER: postgres
+          POSTGRES_PASSWORD: postgres
+          POSTGRES_DB: dendrite
+        ports:
+          # Maps tcp port 5432 on service container to the host
+          - 5432:5432
+        # Set health checks to wait until postgres has started
+        options: >-
+          --health-cmd pg_isready
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
     strategy:
      fail-fast: false
      matrix:

@@ -91,7 +111,12 @@ jobs:
          key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }}
          restore-keys: |
            ${{ runner.os }}-go${{ matrix.go }}-test-
-      - run: go test ./...
+      - run: go test -p 1 ./...
+        env:
+          POSTGRES_HOST: localhost
+          POSTGRES_USER: postgres
+          POSTGRES_PASSWORD: postgres
+          POSTGRES_DB: dendrite

  # build Dendrite for linux with different architectures and go versions
  build:
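
The service container above gives every unit-test run a disposable Postgres reachable on localhost:5432, and `go test -p 1` runs one package at a time so packages do not race on the shared database. A minimal sketch of a test helper that consumes the exported POSTGRES_* variables — the helper name and skip behaviour are illustrative, not part of the Dendrite codebase:

package dbtest

import (
	"database/sql"
	"fmt"
	"os"
	"testing"

	_ "github.com/lib/pq" // registers the "postgres" driver with database/sql
)

// openTestDB connects to the CI service container, or skips the test when
// no Postgres is configured (for example when running locally).
func openTestDB(t *testing.T) *sql.DB {
	host, user := os.Getenv("POSTGRES_HOST"), os.Getenv("POSTGRES_USER")
	if host == "" || user == "" {
		t.Skip("no Postgres service configured")
	}
	connStr := fmt.Sprintf(
		"host=%s user=%s password=%s dbname=%s sslmode=disable",
		host, user, os.Getenv("POSTGRES_PASSWORD"), os.Getenv("POSTGRES_DB"),
	)
	db, err := sql.Open("postgres", connStr)
	if err != nil {
		t.Fatalf("sql.Open: %v", err)
	}
	return db
}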

CHANGES.md (27 lines changed)

@@ -1,5 +1,32 @@
 # Changelog

+## Dendrite 0.8.1 (2022-04-07)
+
+### Fixes
+
+* A bug which could result in the sync API deadlocking due to lock contention in the notifier has been fixed
+
+## Dendrite 0.8.0 (2022-04-07)
+
+### Features
+
+* Support for presence has been added
+  * Presence is not enabled by default
+  * The `global.presence.enable_inbound` and `global.presence.enable_outbound` configuration options allow configuring inbound and outbound presence separately
+* Support for room upgrades via the `/room/{roomID}/upgrade` endpoint has been added (contributed by [DavidSpenler](https://github.com/DavidSpenler), [alexkursell](https://github.com/alexkursell))
+* Support for ignoring users has been added
+* Joined and invite user counts are now sent in the `/sync` room summaries
+* Queued federation and stale device list updates will now be staggered at startup over an up-to 2 minute warm-up period, rather than happening all at once
+* Memory pressure created by the sync notifier has been reduced
+* The EDU server component has now been removed, with the work being moved to more relevant components
+
+### Fixes
+
+* It is now possible to set the `power_level_content_override` when creating a room to include power levels over 100
+* `/send_join` and `/state` responses will now not unmarshal the JSON twice
+* The stream event consumer for push notifications will no longer request membership events that are irrelevant
+* Appservices will no longer incorrectly receive state events twice
+
 ## Dendrite 0.7.0 (2022-03-25)

 ### Features

@@ -52,6 +52,7 @@ import (
 	"golang.org/x/net/http2"
 	"golang.org/x/net/http2/h2c"

+	pineconeConnections "github.com/matrix-org/pinecone/connections"
 	pineconeMulticast "github.com/matrix-org/pinecone/multicast"
 	pineconeRouter "github.com/matrix-org/pinecone/router"
 	pineconeSessions "github.com/matrix-org/pinecone/sessions"

@@ -71,11 +72,9 @@ type DendriteMonolith struct {
 	PineconeRouter    *pineconeRouter.Router
 	PineconeMulticast *pineconeMulticast.Multicast
 	PineconeQUIC      *pineconeSessions.Sessions
+	PineconeManager   *pineconeConnections.ConnectionManager
 	StorageDirectory  string
 	CacheDirectory    string
-	staticPeerURI     string
-	staticPeerMutex   sync.RWMutex
-	staticPeerAttempt chan struct{}
 	listener          net.Listener
 	httpServer        *http.Server
 	processContext    *process.ProcessContext

@@ -104,15 +103,8 @@ func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) {
 }

 func (m *DendriteMonolith) SetStaticPeer(uri string) {
-	m.staticPeerMutex.Lock()
-	m.staticPeerURI = strings.TrimSpace(uri)
-	m.staticPeerMutex.Unlock()
-	m.DisconnectType(int(pineconeRouter.PeerTypeRemote))
-	if uri != "" {
-		go func() {
-			m.staticPeerAttempt <- struct{}{}
-		}()
-	}
+	m.PineconeManager.RemovePeers()
+	m.PineconeManager.AddPeer(strings.TrimSpace(uri))
 }

 func (m *DendriteMonolith) DisconnectType(peertype int) {
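
The hand-rolled retry loop that SetStaticPeer used to feed (deleted in the next hunk) is replaced by pinecone's ConnectionManager: callers declare the peers they want and the manager owns reconnection. A rough sketch of what such a manager does internally, assuming it tracks desired peers and redials dropped ones on a timer — the type and functions below are illustrative, not the pinecone implementation:

package connmgr

import (
	"log"
	"sync"
	"time"
)

// Manager owns the set of peers we want to stay connected to.
type Manager struct {
	mu    sync.Mutex
	peers map[string]bool // URI -> currently connected?
	dial  func(uri string) error
}

func New(dial func(string) error) *Manager {
	m := &Manager{peers: map[string]bool{}, dial: dial}
	go m.worker()
	return m
}

func (m *Manager) AddPeer(uri string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.peers[uri] = false
}

func (m *Manager) RemovePeers() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.peers = map[string]bool{}
}

// worker redials any peer that is not connected, once every five seconds,
// mirroring the 5-second cadence of the loop this commit deletes.
func (m *Manager) worker() {
	for range time.Tick(5 * time.Second) {
		m.mu.Lock()
		for uri, connected := range m.peers {
			if !connected {
				if err := m.dial(uri); err != nil {
					log.Printf("failed to connect to %s: %v", uri, err)
					continue
				}
				m.peers[uri] = true
			}
		}
		m.mu.Unlock()
	}
}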

@@ -210,43 +202,6 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e
 	return loginRes.Device.AccessToken, nil
 }

-func (m *DendriteMonolith) staticPeerConnect() {
-	connected := map[string]bool{} // URI -> connected?
-	attempt := func() {
-		m.staticPeerMutex.RLock()
-		uri := m.staticPeerURI
-		m.staticPeerMutex.RUnlock()
-		if uri == "" {
-			return
-		}
-		for k := range connected {
-			delete(connected, k)
-		}
-		for _, uri := range strings.Split(uri, ",") {
-			connected[strings.TrimSpace(uri)] = false
-		}
-		for _, info := range m.PineconeRouter.Peers() {
-			connected[info.URI] = true
-		}
-		for k, online := range connected {
-			if !online {
-				if err := conn.ConnectToPeer(m.PineconeRouter, k); err != nil {
-					logrus.WithError(err).Error("Failed to connect to static peer")
-				}
-			}
-		}
-	}
-	for {
-		select {
-		case <-m.processContext.Context().Done():
-		case <-m.staticPeerAttempt:
-			attempt()
-		case <-time.After(time.Second * 5):
-			attempt()
-		}
-	}
-}
-
 // nolint:gocyclo
 func (m *DendriteMonolith) Start() {
 	var err error

@@ -284,6 +239,7 @@ func (m *DendriteMonolith) Start() {
 	m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
 	m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"})
 	m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter)
+	m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter)

 	prefix := hex.EncodeToString(pk)
 	cfg := &config.Dendrite{}

@@ -392,9 +348,6 @@ func (m *DendriteMonolith) Start() {

 	m.processContext = base.ProcessContext

-	m.staticPeerAttempt = make(chan struct{}, 1)
-	go m.staticPeerConnect()
-
 	go func() {
 		m.logger.Info("Listening on ", cfg.Global.ServerName)
 		m.logger.Fatal(m.httpServer.Serve(m.PineconeQUIC.Protocol("matrix")))

@@ -42,7 +42,7 @@ type SyncAPIProducer struct {
 }

 // SendData sends account data to the sync API server
-func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string, readMarker *eventutil.ReadMarkerJSON) error {
+func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string, readMarker *eventutil.ReadMarkerJSON, ignoredUsers *types.IgnoredUsers) error {
 	m := &nats.Msg{
 		Subject: p.TopicClientData,
 		Header:  nats.Header{},

@@ -50,9 +50,10 @@ func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string
 	m.Header.Set(jetstream.UserID, userID)

 	data := eventutil.AccountData{
-		RoomID:     roomID,
-		Type:       dataType,
-		ReadMarker: readMarker,
+		RoomID:       roomID,
+		Type:         dataType,
+		ReadMarker:   readMarker,
+		IgnoredUsers: ignoredUsers,
 	}
 	var err error
 	m.Data, err = json.Marshal(data)
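
Callers now pass the parsed ignore list as a fifth argument; nil means "no ignore-list change", which is what the m.fully_read and m.tag call sites below pass. A hedged sketch of forwarding an m.ignored_user_list update through the extended signature — the helper and its error handling are illustrative:

package example

import (
	"encoding/json"

	"github.com/matrix-org/dendrite/clientapi/producers"
	"github.com/matrix-org/dendrite/syncapi/types"
	log "github.com/sirupsen/logrus"
)

// forwardIgnoredUsers parses the client-supplied body and hands it to the
// sync API producer. In this sketch a malformed body is forwarded as nil
// rather than failing the request.
func forwardIgnoredUsers(p *producers.SyncAPIProducer, userID string, body []byte) error {
	ignored := &types.IgnoredUsers{}
	if err := json.Unmarshal(body, ignored); err != nil {
		log.WithError(err).Warn("malformed m.ignored_user_list body")
		ignored = nil
	}
	return p.SendData(userID, "", "m.ignored_user_list", nil, ignored)
}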

@@ -25,6 +25,7 @@ import (
 	"github.com/matrix-org/dendrite/clientapi/producers"
 	"github.com/matrix-org/dendrite/internal/eventutil"
 	roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
+	"github.com/matrix-org/dendrite/syncapi/types"
 	"github.com/matrix-org/dendrite/userapi/api"

 	"github.com/matrix-org/util"

@@ -94,10 +95,10 @@ func SaveAccountData(
 		}
 	}

-	if dataType == "m.fully_read" {
+	if dataType == "m.fully_read" || dataType == "m.push_rules" {
 		return util.JSONResponse{
 			Code: http.StatusForbidden,
-			JSON: jsonerror.Forbidden("Unable to set read marker"),
+			JSON: jsonerror.Forbidden(fmt.Sprintf("Unable to modify %q using this API", dataType)),
 		}
 	}
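
The guard now covers both account-data types that have dedicated endpoints and therefore must not be written through the generic route. Shown standalone as a sketch — the helper name is hypothetical:

package example

// isWriteProtected reports whether an account-data type may only be
// modified via its own dedicated API (/read_markers, /pushrules) rather
// than the generic /user/{userId}/account_data/{type} endpoint.
func isWriteProtected(dataType string) bool {
	switch dataType {
	case "m.fully_read", "m.push_rules":
		return true
	default:
		return false
	}
}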

@@ -126,8 +127,14 @@ func SaveAccountData(
 		return util.ErrorResponse(err)
 	}

+	var ignoredUsers *types.IgnoredUsers
+	if dataType == "m.ignored_user_list" {
+		ignoredUsers = &types.IgnoredUsers{}
+		_ = json.Unmarshal(body, ignoredUsers)
+	}
+
 	// TODO: user API should do this since it's account data
-	if err := syncProducer.SendData(userID, roomID, dataType, nil); err != nil {
+	if err := syncProducer.SendData(userID, roomID, dataType, nil, ignoredUsers); err != nil {
 		util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
 		return jsonerror.InternalServerError()
 	}

@@ -184,7 +191,7 @@ func SaveReadMarker(
 		return util.ErrorResponse(err)
 	}

-	if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read", &r); err != nil {
+	if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read", &r, nil); err != nil {
 		util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
 		return jsonerror.InternalServerError()
 	}

@@ -70,7 +70,6 @@ func SetPresence(
 			JSON: jsonerror.Unknown(fmt.Sprintf("Unknown presence '%s'.", presence.Presence)),
 		}
 	}
-
 	err := producer.SendPresence(req.Context(), userID, presenceStatus, presence.StatusMsg)
 	if err != nil {
 		log.WithError(err).Errorf("failed to update presence")

@@ -180,7 +180,7 @@ func SetAvatarURL(
 		return jsonerror.InternalServerError()
 	}

-	if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, false); err != nil {
+	if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil {
 		util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
 		return jsonerror.InternalServerError()
 	}

@@ -64,11 +64,6 @@ const (
 	sessionIDLength = 24
 )

-func init() {
-	// Register prometheus metrics. They must be registered to be exposed.
-	prometheus.MustRegister(amtRegUsers)
-}
-
 // sessionsDict keeps track of completed auth stages for each session.
 // It shouldn't be passed by value because it contains a mutex.
 type sessionsDict struct {

@@ -98,7 +98,7 @@ func PutTag(
 		return jsonerror.InternalServerError()
 	}

-	if err = syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil {
+	if err = syncProducer.SendData(userID, roomID, "m.tag", nil, nil); err != nil {
 		logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
 	}

@@ -151,7 +151,7 @@ func DeleteTag(
 	}

 	// TODO: user API should do this since it's account data
-	if err := syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil {
+	if err := syncProducer.SendData(userID, roomID, "m.tag", nil, nil); err != nil {
 		logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
 	}

@@ -37,6 +37,7 @@ import (
 	"github.com/matrix-org/gomatrixserverlib"
 	"github.com/matrix-org/util"
 	"github.com/nats-io/nats.go"
+	"github.com/prometheus/client_golang/prometheus"
 	"github.com/sirupsen/logrus"
 )

@@ -60,6 +61,8 @@ func Setup(
 	extRoomsProvider api.ExtraPublicRoomsProvider,
 	mscCfg *config.MSCs, natsClient *nats.Conn,
 ) {
+	prometheus.MustRegister(amtRegUsers, sendEventDuration)
+
 	rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
 	userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg)
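
Registering collectors in Setup rather than in init() — the init blocks are deleted elsewhere in this commit — means metrics are only registered when a component is actually wired up, so importing a package from a test or another binary no longer panics on duplicate registration. A minimal sketch of the pattern, with an invented metric:

package example

import "github.com/prometheus/client_golang/prometheus"

var requestsTotal = prometheus.NewCounter(prometheus.CounterOpts{
	Namespace: "dendrite",
	Name:      "example_requests_total",
	Help:      "Total number of requests handled.",
})

// Setup wires up the component and registers its metrics exactly once.
// prometheus.MustRegister panics on double registration, which is why
// doing this from init() is fragile when many binaries and tests import
// the package.
func Setup() {
	prometheus.MustRegister(requestsTotal)
}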

@@ -46,10 +46,6 @@ var (
 	userRoomSendMutexes sync.Map // (roomID+userID) -> mutex. mutexes to ensure correct ordering of sendEvents
 )

-func init() {
-	prometheus.MustRegister(sendEventDuration)
-}
-
 var sendEventDuration = prometheus.NewHistogramVec(
 	prometheus.HistogramOpts{
 		Namespace: "dendrite",

@@ -1,156 +0,0 @@
-// Copyright 2017 Vector Creations Ltd
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package main
-
-import (
-	"flag"
-	"fmt"
-	"net/http"
-	"net/http/httputil"
-	"net/url"
-	"os"
-	"strings"
-	"time"
-
-	log "github.com/sirupsen/logrus"
-)
-
-const usage = `Usage: %s
-
-Create a single endpoint URL which clients can be pointed at.
-
-The client-server API in Dendrite is split across multiple processes
-which listen on multiple ports. You cannot point a Matrix client at
-any of those ports, as there will be unimplemented functionality.
-In addition, all client-server API processes start with the additional
-path prefix '/api', which Matrix clients will be unaware of.
-
-This tool will proxy requests for all client-server URLs and forward
-them to their respective process. It will also add the '/api' path
-prefix to incoming requests.
-
-THIS TOOL IS FOR TESTING AND NOT INTENDED FOR PRODUCTION USE.
-
-Arguments:
-
-`
-
-var (
-	syncServerURL = flag.String("sync-api-server-url", "", "The base URL of the listening 'dendrite-sync-api-server' process. E.g. 'http://localhost:4200'")
-	clientAPIURL  = flag.String("client-api-server-url", "", "The base URL of the listening 'dendrite-client-api-server' process. E.g. 'http://localhost:4321'")
-	mediaAPIURL   = flag.String("media-api-server-url", "", "The base URL of the listening 'dendrite-media-api-server' process. E.g. 'http://localhost:7779'")
-	bindAddress   = flag.String("bind-address", ":8008", "The listening port for the proxy.")
-	certFile      = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS")
-	keyFile       = flag.String("tls-key", "", "The PEM private key to use for TLS")
-)
-
-func makeProxy(targetURL string) (*httputil.ReverseProxy, error) {
-	targetURL = strings.TrimSuffix(targetURL, "/")
-
-	// Check that we can parse the URL.
-	_, err := url.Parse(targetURL)
-	if err != nil {
-		return nil, err
-	}
-	return &httputil.ReverseProxy{
-		Director: func(req *http.Request) {
-			// URL.Path() removes the % escaping from the path.
-			// The % encoding will be added back when the url is encoded
-			// when the request is forwarded.
-			// This means that we will lose any unessecary escaping from the URL.
-			// Pratically this means that any distinction between '%2F' and '/'
-			// in the URL will be lost by the time it reaches the target.
-			path := req.URL.Path
-			log.WithFields(log.Fields{
-				"path":   path,
-				"url":    targetURL,
-				"method": req.Method,
-			}).Print("proxying request")
-			newURL, err := url.Parse(targetURL)
-			// Set the path separately as we need to preserve '#' characters
-			// that would otherwise be interpreted as being the start of a URL
-			// fragment.
-			newURL.Path += path
-			if err != nil {
-				// We already checked that we can parse the URL
-				// So this shouldn't ever get hit.
-				panic(err)
-			}
-			// Copy the query parameters from the request.
-			newURL.RawQuery = req.URL.RawQuery
-			req.URL = newURL
-		},
-	}, nil
-}
-
-func main() {
-	flag.Usage = func() {
-		fmt.Fprintf(os.Stderr, usage, os.Args[0])
-		flag.PrintDefaults()
-	}
-
-	flag.Parse()
-
-	if *syncServerURL == "" {
-		flag.Usage()
-		fmt.Fprintln(os.Stderr, "no --sync-api-server-url specified.")
-		os.Exit(1)
-	}
-
-	if *clientAPIURL == "" {
-		flag.Usage()
-		fmt.Fprintln(os.Stderr, "no --client-api-server-url specified.")
-		os.Exit(1)
-	}
-
-	if *mediaAPIURL == "" {
-		flag.Usage()
-		fmt.Fprintln(os.Stderr, "no --media-api-server-url specified.")
-		os.Exit(1)
-	}
-	syncProxy, err := makeProxy(*syncServerURL)
-	if err != nil {
-		panic(err)
-	}
-	clientProxy, err := makeProxy(*clientAPIURL)
-	if err != nil {
-		panic(err)
-	}
-	mediaProxy, err := makeProxy(*mediaAPIURL)
-	if err != nil {
-		panic(err)
-	}
-
-	http.Handle("/_matrix/client/r0/sync", syncProxy)
-	http.Handle("/_matrix/media/v1/", mediaProxy)
-	http.Handle("/", clientProxy)
-
-	srv := &http.Server{
-		Addr:         *bindAddress,
-		ReadTimeout:  1 * time.Minute, // how long we wait for the client to send the entire request (after connection accept)
-		WriteTimeout: 5 * time.Minute, // how long the proxy has to write the full response
-	}
-
-	fmt.Println("Proxying requests to:")
-	fmt.Println("  /_matrix/client/r0/sync => ", *syncServerURL+"/api/_matrix/client/r0/sync")
-	fmt.Println("  /_matrix/media/v1       => ", *mediaAPIURL+"/api/_matrix/media/v1")
-	fmt.Println("  /*                      => ", *clientAPIURL+"/api/*")
-	fmt.Println("Listening on ", *bindAddress)
-	if *certFile != "" && *keyFile != "" {
-		panic(srv.ListenAndServeTLS(*certFile, *keyFile))
-	} else {
-		panic(srv.ListenAndServe())
-	}
-}

@@ -25,7 +25,6 @@ import (
 	"net"
 	"net/http"
 	"os"
-	"strings"
 	"time"

 	"github.com/gorilla/mux"

@@ -47,6 +46,7 @@ import (
 	"github.com/matrix-org/dendrite/userapi"
 	"github.com/matrix-org/gomatrixserverlib"

+	pineconeConnections "github.com/matrix-org/pinecone/connections"
 	pineconeMulticast "github.com/matrix-org/pinecone/multicast"
 	pineconeRouter "github.com/matrix-org/pinecone/router"
 	pineconeSessions "github.com/matrix-org/pinecone/sessions"

@@ -90,6 +90,13 @@ func main() {
 	}

 	pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
+	pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
+	pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
+	pManager := pineconeConnections.NewConnectionManager(pRouter)
+	pMulticast.Start()
+	if instancePeer != nil && *instancePeer != "" {
+		pManager.AddPeer(*instancePeer)
+	}

 	go func() {
 		listener, err := net.Listen("tcp", *instanceListen)

@@ -119,36 +126,6 @@ func main() {
 		}
 	}()

-	pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
-	pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter)
-	pMulticast.Start()
-
-	connectToStaticPeer := func() {
-		connected := map[string]bool{} // URI -> connected?
-		for _, uri := range strings.Split(*instancePeer, ",") {
-			connected[strings.TrimSpace(uri)] = false
-		}
-		attempt := func() {
-			for k := range connected {
-				connected[k] = false
-			}
-			for _, info := range pRouter.Peers() {
-				connected[info.URI] = true
-			}
-			for k, online := range connected {
-				if !online {
-					if err := conn.ConnectToPeer(pRouter, k); err != nil {
-						logrus.WithError(err).Error("Failed to connect to static peer")
-					}
-				}
-			}
-		}
-		for {
-			attempt()
-			time.Sleep(time.Second * 5)
-		}
-	}
-
 	cfg := &config.Dendrite{}
 	cfg.Defaults(true)
 	cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))

@@ -268,7 +245,6 @@ func main() {
 		Handler: pMux,
 	}

-	go connectToStaticPeer()
 	go func() {
 		pubkey := pRouter.PublicKey()
 		logrus.Info("Listening on ", hex.EncodeToString(pubkey[:]))

@@ -22,7 +22,6 @@ import (
 	"encoding/hex"
 	"fmt"
 	"syscall/js"
-	"time"

 	"github.com/gorilla/mux"
 	"github.com/matrix-org/dendrite/appservice"

@@ -44,6 +43,7 @@ import (

 	_ "github.com/matrix-org/go-sqlite3-js"

+	pineconeConnections "github.com/matrix-org/pinecone/connections"
 	pineconeRouter "github.com/matrix-org/pinecone/router"
 	pineconeSessions "github.com/matrix-org/pinecone/sessions"
 )

@@ -154,6 +154,8 @@ func startup() {

 	pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false)
 	pSessions := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"})
+	pManager := pineconeConnections.NewConnectionManager(pRouter)
+	pManager.AddPeer("wss://pinecone.matrix.org/public")

 	cfg := &config.Dendrite{}
 	cfg.Defaults(true)

@@ -237,20 +239,4 @@ func startup() {
 		}
 		s.ListenAndServe("fetch")
 	}()
-
-	// Connect to the static peer
-	go func() {
-		for {
-			if pRouter.PeerCount(pineconeRouter.PeerTypeRemote) == 0 {
-				if err := conn.ConnectToPeer(pRouter, publicPeer); err != nil {
-					logrus.WithError(err).Error("Failed to connect to static peer")
-				}
-			}
-			select {
-			case <-base.ProcessContext.Context().Done():
-				return
-			case <-time.After(time.Second * 5):
-			}
-		}
-	}()
 }

@@ -1,138 +0,0 @@
-// Copyright 2017 Vector Creations Ltd
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package main
-
-import (
-	"flag"
-	"fmt"
-	"net/http"
-	"net/http/httputil"
-	"net/url"
-	"os"
-	"strings"
-	"time"
-
-	log "github.com/sirupsen/logrus"
-)
-
-const usage = `Usage: %s
-
-Create a single endpoint URL which remote matrix servers can be pointed at.
-
-The server-server API in Dendrite is split across multiple processes
-which listen on multiple ports. You cannot point a Matrix server at
-any of those ports, as there will be unimplemented functionality.
-In addition, all server-server API processes start with the additional
-path prefix '/api', which Matrix servers will be unaware of.
-
-This tool will proxy requests for all server-server URLs and forward
-them to their respective process. It will also add the '/api' path
-prefix to incoming requests.
-
-THIS TOOL IS FOR TESTING AND NOT INTENDED FOR PRODUCTION USE.
-
-Arguments:
-
-`
-
-var (
-	federationAPIURL = flag.String("federation-api-url", "", "The base URL of the listening 'dendrite-federation-api-server' process. E.g. 'http://localhost:4200'")
-	mediaAPIURL      = flag.String("media-api-server-url", "", "The base URL of the listening 'dendrite-media-api-server' process. E.g. 'http://localhost:7779'")
-	bindAddress      = flag.String("bind-address", ":8448", "The listening port for the proxy.")
-	certFile         = flag.String("tls-cert", "server.crt", "The PEM formatted X509 certificate to use for TLS")
-	keyFile          = flag.String("tls-key", "server.key", "The PEM private key to use for TLS")
-)
-
-func makeProxy(targetURL string) (*httputil.ReverseProxy, error) {
-	if !strings.HasSuffix(targetURL, "/") {
-		targetURL += "/"
-	}
-	// Check that we can parse the URL.
-	_, err := url.Parse(targetURL)
-	if err != nil {
-		return nil, err
-	}
-	return &httputil.ReverseProxy{
-		Director: func(req *http.Request) {
-			// URL.Path() removes the % escaping from the path.
-			// The % encoding will be added back when the url is encoded
-			// when the request is forwarded.
-			// This means that we will lose any unessecary escaping from the URL.
-			// Pratically this means that any distinction between '%2F' and '/'
-			// in the URL will be lost by the time it reaches the target.
-			path := req.URL.Path
-			log.WithFields(log.Fields{
-				"path":   path,
-				"url":    targetURL,
-				"method": req.Method,
-			}).Print("proxying request")
-			newURL, err := url.Parse(targetURL + path)
-			if err != nil {
-				// We already checked that we can parse the URL
-				// So this shouldn't ever get hit.
-				panic(err)
-			}
-			// Copy the query parameters from the request.
-			newURL.RawQuery = req.URL.RawQuery
-			req.URL = newURL
-		},
-	}, nil
-}
-
-func main() {
-	flag.Usage = func() {
-		fmt.Fprintf(os.Stderr, usage, os.Args[0])
-		flag.PrintDefaults()
-	}
-
-	flag.Parse()
-
-	if *federationAPIURL == "" {
-		flag.Usage()
-		fmt.Fprintln(os.Stderr, "no --federation-api-url specified.")
-		os.Exit(1)
-	}
-
-	if *mediaAPIURL == "" {
-		flag.Usage()
-		fmt.Fprintln(os.Stderr, "no --media-api-server-url specified.")
-		os.Exit(1)
-	}
-
-	federationProxy, err := makeProxy(*federationAPIURL)
-	if err != nil {
-		panic(err)
-	}
-
-	mediaProxy, err := makeProxy(*mediaAPIURL)
-	if err != nil {
-		panic(err)
-	}
-
-	http.Handle("/_matrix/media/v1/", mediaProxy)
-	http.Handle("/", federationProxy)
-
-	srv := &http.Server{
-		Addr:         *bindAddress,
-		ReadTimeout:  1 * time.Minute, // how long we wait for the client to send the entire request (after connection accept)
-		WriteTimeout: 5 * time.Minute, // how long the proxy has to write the full response
-	}
-
-	fmt.Println("Proxying requests to:")
-	fmt.Println("  /_matrix/media/v1 => ", *mediaAPIURL+"/api/_matrix/media/v1")
-	fmt.Println("  /*                => ", *federationAPIURL+"/api/*")
-	fmt.Println("Listening on ", *bindAddress)
-	panic(srv.ListenAndServeTLS(*certFile, *keyFile))
-}

@@ -29,6 +29,7 @@ import (
 	userapi "github.com/matrix-org/dendrite/userapi/api"
 	"github.com/matrix-org/gomatrixserverlib"
 	"github.com/matrix-org/util"
+	"github.com/prometheus/client_golang/prometheus"
 	"github.com/sirupsen/logrus"
 )

@@ -53,6 +54,10 @@ func Setup(
 	servers federationAPI.ServersInRoomProvider,
 	producer *producers.SyncAPIProducer,
 ) {
+	prometheus.MustRegister(
+		pduCountTotal, eduCountTotal,
+	)
+
 	v2keysmux := keyMux.PathPrefix("/v2").Subrouter()
 	v1fedmux := fedMux.PathPrefix("/v1").Subrouter()
 	v2fedmux := fedMux.PathPrefix("/v2").Subrouter()

@@ -74,12 +74,6 @@ var (
 	)
 )

-func init() {
-	prometheus.MustRegister(
-		pduCountTotal, eduCountTotal,
-	)
-}
-
 var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse

 // Send implements /_matrix/federation/v1/send/{txnID}

@@ -413,7 +407,6 @@ func (t *txnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) e
 	for _, content := range payload.Push {
 		presence, ok := syncTypes.PresenceFromString(content.Presence)
 		if !ok {
-			logrus.Warnf("invalid presence '%s', skipping.", content.Presence)
 			continue
 		}
 		if err := t.producer.SendPresence(ctx, content.UserID, presence, content.StatusMsg, content.LastActiveAgo); err != nil {

go.mod (36 lines changed)

@@ -12,20 +12,20 @@ require (
 	github.com/MFAshby/stdemuxerhook v1.0.0
 	github.com/Masterminds/semver/v3 v3.1.1
 	github.com/codeclysm/extract v2.2.0+incompatible
-	github.com/containerd/containerd v1.5.9 // indirect
-	github.com/docker/docker v20.10.12+incompatible
+	github.com/containerd/containerd v1.6.2 // indirect
+	github.com/docker/docker v20.10.14+incompatible
 	github.com/docker/go-connections v0.4.0
-	github.com/frankban/quicktest v1.14.0 // indirect
-	github.com/getsentry/sentry-go v0.12.0
+	github.com/frankban/quicktest v1.14.3 // indirect
+	github.com/getsentry/sentry-go v0.13.0
 	github.com/gologme/log v1.3.0
-	github.com/google/go-cmp v0.5.6
-	github.com/google/uuid v1.2.0
+	github.com/google/go-cmp v0.5.7
+	github.com/google/uuid v1.3.0
 	github.com/gorilla/mux v1.8.0
-	github.com/gorilla/websocket v1.4.2
+	github.com/gorilla/websocket v1.5.0
 	github.com/h2non/filetype v1.1.3 // indirect
 	github.com/hashicorp/golang-lru v0.5.4
-	github.com/juju/testing v0.0.0-20211215003918-77eb13d6cad2 // indirect
-	github.com/lib/pq v1.10.4
+	github.com/juju/testing v0.0.0-20220203020004-a0ff61f03494 // indirect
+	github.com/lib/pq v1.10.5
 	github.com/libp2p/go-libp2p v0.13.0
 	github.com/libp2p/go-libp2p-circuit v0.4.0
 	github.com/libp2p/go-libp2p-core v0.8.3

@@ -38,16 +38,16 @@ require (
 	github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
 	github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
 	github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
-	github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5
-	github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d
+	github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f
+	github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48
 	github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
 	github.com/mattn/go-sqlite3 v1.14.10
 	github.com/morikuni/aec v1.0.0 // indirect
 	github.com/nats-io/nats-server/v2 v2.7.4-0.20220309205833-773636c1c5bb
 	github.com/nats-io/nats.go v1.13.1-0.20220308171302-2f2f6968e98d
 	github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
 	github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
 	github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31
 	github.com/opencontainers/image-spec v1.0.2 // indirect
 	github.com/opentracing/opentracing-go v1.2.0
 	github.com/patrickmn/go-cache v2.1.0+incompatible
 	github.com/pkg/errors v0.9.1

@@ -60,14 +60,14 @@ require (
 	github.com/uber/jaeger-lib v2.4.1+incompatible
 	github.com/yggdrasil-network/yggdrasil-go v0.4.3
 	go.uber.org/atomic v1.9.0
-	golang.org/x/crypto v0.0.0-20220214200702-86341886e292
-	golang.org/x/image v0.0.0-20211028202545-6944b10bf410
-	golang.org/x/mobile v0.0.0-20220325161704-447654d348e3
-	golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
+	golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29
+	golang.org/x/image v0.0.0-20220321031419-a8550c1d254a
+	golang.org/x/mobile v0.0.0-20220407111146-e579adbbc4a2
+	golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3
 	golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect
 	golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
-	gopkg.in/h2non/bimg.v1 v1.1.5
+	gopkg.in/h2non/bimg.v1 v1.1.9
 	gopkg.in/yaml.v2 v2.4.0
 	gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
 	nhooyr.io/websocket v1.8.7
 )

@@ -17,6 +17,8 @@ package eventutil
 import (
 	"errors"
 	"strconv"
+
+	"github.com/matrix-org/dendrite/syncapi/types"
 )

 // ErrProfileNoExists is returned when trying to lookup a user's profile that

@@ -26,9 +28,10 @@ var ErrProfileNoExists = errors.New("no known profile for given user ID")
 // AccountData represents account data sent from the client API server to the
 // sync API server
 type AccountData struct {
-	RoomID     string          `json:"room_id"`
-	Type       string          `json:"type"`
-	ReadMarker *ReadMarkerJSON `json:"read_marker,omitempty"` // optional
+	RoomID       string              `json:"room_id"`
+	Type         string              `json:"type"`
+	ReadMarker   *ReadMarkerJSON     `json:"read_marker,omitempty"`   // optional
+	IgnoredUsers *types.IgnoredUsers `json:"ignored_users,omitempty"` // optional
 }

 type ReadMarkerJSON struct {

@@ -169,8 +169,9 @@ func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request)
 	return promhttp.InstrumentHandlerCounter(
 		promauto.NewCounterVec(
 			prometheus.CounterOpts{
-				Name: metricsName,
-				Help: "Total number of http requests for HTML resources",
+				Name:      metricsName,
+				Help:      "Total number of http requests for HTML resources",
+				Namespace: "dendrite",
 			},
 			[]string{"code"},
 		),

@@ -201,7 +202,28 @@ func MakeInternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
 		h.ServeHTTP(w, req)
 	}

-	return http.HandlerFunc(withSpan)
+	return promhttp.InstrumentHandlerCounter(
+		promauto.NewCounterVec(
+			prometheus.CounterOpts{
+				Name:      metricsName + "_requests_total",
+				Help:      "Total number of internal API calls",
+				Namespace: "dendrite",
+			},
+			[]string{"code"},
+		),
+		promhttp.InstrumentHandlerResponseSize(
+			promauto.NewHistogramVec(
+				prometheus.HistogramOpts{
+					Namespace: "dendrite",
+					Name:      metricsName + "_response_size_bytes",
+					Help:      "A histogram of response sizes for requests.",
+					Buckets:   []float64{200, 500, 900, 1500, 5000, 15000, 50000, 100000},
+				},
+				[]string{},
+			),
+			http.HandlerFunc(withSpan),
+		),
+	)
 }

 // MakeFedAPI makes an http.Handler that checks matrix federation authentication.
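
Each promhttp.Instrument* wrapper decorates the next handler, so a single request updates every collector in the chain. A standalone sketch of the same chaining, with invented metric names:

package example

import (
	"net/http"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

// instrumented wraps next with a request counter and a response-size
// histogram; promauto registers both with the default registry.
func instrumented(next http.Handler) http.Handler {
	counter := promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "dendrite",
		Name:      "example_requests_total",
		Help:      "Total requests, labelled by status code.",
	}, []string{"code"})
	sizes := promauto.NewHistogramVec(prometheus.HistogramOpts{
		Namespace: "dendrite",
		Name:      "example_response_size_bytes",
		Help:      "Response sizes in bytes.",
		Buckets:   []float64{200, 500, 900, 1500, 5000},
	}, []string{})
	// The outermost wrapper runs first; both update once the response is written.
	return promhttp.InstrumentHandlerCounter(counter,
		promhttp.InstrumentHandlerResponseSize(sizes, next))
}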

@@ -16,8 +16,8 @@ var build string

 const (
 	VersionMajor = 0
-	VersionMinor = 7
-	VersionPatch = 0
+	VersionMinor = 8
+	VersionPatch = 1
 	VersionTag   = "" // example: "rc1"
 )

@@ -167,6 +167,7 @@ func (r *Inputer) startWorkerForRoom(roomID string) {
 // will look to see if we have a worker for that room which has its
 // own consumer. If we don't, we'll start one.
 func (r *Inputer) Start() error {
+	prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration)
 	_, err := r.JetStream.Subscribe(
 		"", // This is blank because we specified it in BindStream.
 		func(m *nats.Msg) {

@@ -421,10 +422,6 @@ func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) er
 	return nil
 }

-func init() {
-	prometheus.MustRegister(roomserverInputBackpressure)
-}
-
 var roomserverInputBackpressure = prometheus.NewGaugeVec(
 	prometheus.GaugeOpts{
 		Namespace: "dendrite",

@@ -37,10 +37,6 @@ import (
 	"github.com/sirupsen/logrus"
 )

-func init() {
-	prometheus.MustRegister(processRoomEventDuration)
-}
-
 // TODO: Does this value make sense?
 const MaximumMissingProcessingTime = time.Minute * 2

@@ -2,7 +2,6 @@ package input_test

 import (
 	"context"
-	"fmt"
 	"os"
 	"testing"
 	"time"

@@ -12,30 +11,22 @@ import (
 	"github.com/matrix-org/dendrite/roomserver/internal/input"
 	"github.com/matrix-org/dendrite/roomserver/storage"
 	"github.com/matrix-org/dendrite/setup/config"
+	"github.com/matrix-org/dendrite/setup/jetstream"
+	"github.com/matrix-org/dendrite/setup/process"
 	"github.com/matrix-org/gomatrixserverlib"
+	"github.com/nats-io/nats.go"
 )

-func psqlConnectionString() config.DataSource {
-	user := os.Getenv("POSTGRES_USER")
-	if user == "" {
-		user = "dendrite"
-	}
-	dbName := os.Getenv("POSTGRES_DB")
-	if dbName == "" {
-		dbName = "dendrite"
-	}
-	connStr := fmt.Sprintf(
-		"user=%s dbname=%s sslmode=disable", user, dbName,
-	)
-	password := os.Getenv("POSTGRES_PASSWORD")
-	if password != "" {
-		connStr += fmt.Sprintf(" password=%s", password)
-	}
-	host := os.Getenv("POSTGRES_HOST")
-	if host != "" {
-		connStr += fmt.Sprintf(" host=%s", host)
-	}
-	return config.DataSource(connStr)
-}
+var js nats.JetStreamContext
+var jc *nats.Conn
+
+func TestMain(m *testing.M) {
+	var pc *process.ProcessContext
+	pc, js, jc = jetstream.PrepareForTests()
+	code := m.Run()
+	pc.ShutdownDendrite()
+	pc.WaitForComponentsToFinish()
+	os.Exit(code)
+}

 func TestSingleTransactionOnInput(t *testing.T) {

@@ -63,7 +54,7 @@ func TestSingleTransactionOnInput(t *testing.T) {
 	}
 	db, err := storage.Open(
 		&config.DatabaseOptions{
-			ConnectionString:   psqlConnectionString(),
+			ConnectionString:   "",
 			MaxOpenConnections: 1,
 			MaxIdleConnections: 1,
 		},

@@ -74,7 +65,9 @@ func TestSingleTransactionOnInput(t *testing.T) {
 		t.SkipNow()
 	}
 	inputter := &input.Inputer{
-		DB: db,
+		DB:         db,
+		JetStream:  js,
+		NATSClient: jc,
 	}
 	res := &api.InputRoomEventsResponse{}
 	inputter.InputRoomEvents(

@@ -13,12 +13,22 @@ import (
 	"github.com/sirupsen/logrus"

 	natsserver "github.com/nats-io/nats-server/v2/server"
+	"github.com/nats-io/nats.go"
 	natsclient "github.com/nats-io/nats.go"
 )

 var natsServer *natsserver.Server
 var natsServerMutex sync.Mutex

+func PrepareForTests() (*process.ProcessContext, nats.JetStreamContext, *nats.Conn) {
+	cfg := &config.Dendrite{}
+	cfg.Defaults(true)
+	cfg.Global.JetStream.InMemory = true
+	pc := process.NewProcessContext()
+	js, jc := Prepare(pc, &cfg.Global.JetStream)
+	return pc, js, jc
+}
+
 func Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient.JetStreamContext, *natsclient.Conn) {
 	// check if we need an in-process NATS Server
 	if len(cfg.Addresses) != 0 {

@@ -119,6 +119,15 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg)
 		return false
 	}

+	if output.IgnoredUsers != nil {
+		if err := s.db.UpdateIgnoresForUser(ctx, userID, output.IgnoredUsers); err != nil {
+			log.WithError(err).WithFields(logrus.Fields{
+				"user_id": userID,
+			}).Errorf("Failed to update ignored users")
+			sentry.CaptureException(err)
+		}
+	}
+
 	s.stream.Advance(streamPos)
 	s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos})

@@ -25,28 +25,28 @@ import (
 	log "github.com/sirupsen/logrus"
 )

+// NOTE: ALL FUNCTIONS IN THIS FILE PREFIXED WITH _ ARE NOT THREAD-SAFE
+// AND MUST ONLY BE CALLED WHEN THE NOTIFIER LOCK IS HELD!
+
 // Notifier will wake up sleeping requests when there is some new data.
 // It does not tell requests what that data is, only the sync position which
 // they can use to get at it. This is done to prevent races whereby we tell the caller
 // the event, but the token has already advanced by the time they fetch it, resulting
 // in missed events.
 type Notifier struct {
+	lock *sync.RWMutex
 	// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
 	roomIDToJoinedUsers map[string]userIDSet
 	// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
 	roomIDToPeekingDevices map[string]peekingDeviceSet
-	// Protects currPos and userStreams.
-	streamLock *sync.Mutex
 	// The latest sync position
 	currPos types.StreamingToken
 	// A map of user_id => device_id => UserStream which can be used to wake a given user's /sync request.
 	userDeviceStreams map[string]map[string]*UserDeviceStream
 	// The last time we cleaned out stale entries from the userStreams map
 	lastCleanUpTime time.Time
-	// Protects roomIDToJoinedUsers and roomIDToPeekingDevices
-	mapLock *sync.RWMutex
 	// This map is reused to prevent allocations and GC pressure in SharedUsers.
-	_sharedUsers map[string]struct{}
+	_sharedUserMap map[string]struct{}
 }

 // NewNotifier creates a new notifier set to the given sync position.
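
The deadlock fix collapses the notifier's two mutexes (streamLock and mapLock) into one RWMutex with a single convention: exported methods take the lock, _-prefixed helpers assume it is already held. Re-entry — an exported method calling another exported method and re-acquiring a non-reentrant lock — becomes structurally impossible. A standalone sketch of the convention, with illustrative names:

package example

import "sync"

type Registry struct {
	lock  sync.RWMutex
	rooms map[string][]string // roomID -> userIDs
}

// JoinedUsers is the thread-safe entry point.
func (r *Registry) JoinedUsers(roomID string) []string {
	r.lock.RLock()
	defer r.lock.RUnlock()
	return r._joinedUsers(roomID)
}

// _joinedUsers must only be called with r.lock held. Internal callers that
// already hold the lock call this directly; calling JoinedUsers instead
// could deadlock once a writer is queued on the RWMutex.
func (r *Registry) _joinedUsers(roomID string) []string {
	return append([]string(nil), r.rooms[roomID]...)
}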

@@ -57,18 +57,17 @@ func NewNotifier() *Notifier {
 		roomIDToJoinedUsers:    make(map[string]userIDSet),
 		roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
 		userDeviceStreams:      make(map[string]map[string]*UserDeviceStream),
-		streamLock:             &sync.Mutex{},
-		mapLock:                &sync.RWMutex{},
+		lock:                   &sync.RWMutex{},
 		lastCleanUpTime:        time.Now(),
-		_sharedUsers:           map[string]struct{}{},
+		_sharedUserMap:         map[string]struct{}{},
 	}
 }

 // SetCurrentPosition sets the current streaming positions.
 // This must be called directly after NewNotifier and initialising the streams.
 func (n *Notifier) SetCurrentPosition(currPos types.StreamingToken) {
-	n.streamLock.Lock()
-	defer n.streamLock.Unlock()
+	n.lock.Lock()
+	defer n.lock.Unlock()

 	n.currPos = currPos
 }

@@ -89,17 +88,16 @@ func (n *Notifier) OnNewEvent(
 ) {
 	// update the current position then notify relevant /sync streams.
 	// This needs to be done PRIOR to waking up users as they will read this value.
-	n.streamLock.Lock()
-	defer n.streamLock.Unlock()
-
+	n.lock.Lock()
+	defer n.lock.Unlock()
 	n.currPos.ApplyUpdates(posUpdate)
-	n.removeEmptyUserStreams()
+	n._removeEmptyUserStreams()

 	if ev != nil {
 		// Map this event's room_id to a list of joined users, and wake them up.
-		usersToNotify := n.JoinedUsers(ev.RoomID())
+		usersToNotify := n._joinedUsers(ev.RoomID())
 		// Map this event's room_id to a list of peeking devices, and wake them up.
-		peekingDevicesToNotify := n.PeekingDevices(ev.RoomID())
+		peekingDevicesToNotify := n._peekingDevices(ev.RoomID())
 		// If this is an invite, also add in the invitee to this list.
 		if ev.Type() == "m.room.member" && ev.StateKey() != nil {
 			targetUserID := *ev.StateKey()

@@ -117,20 +115,20 @@ func (n *Notifier) OnNewEvent(
 				// Manually append the new user's ID so they get notified
 				// along all members in the room
 				usersToNotify = append(usersToNotify, targetUserID)
-				n.addJoinedUser(ev.RoomID(), targetUserID)
+				n._addJoinedUser(ev.RoomID(), targetUserID)
 			case gomatrixserverlib.Leave:
 				fallthrough
 			case gomatrixserverlib.Ban:
-				n.removeJoinedUser(ev.RoomID(), targetUserID)
+				n._removeJoinedUser(ev.RoomID(), targetUserID)
 			}
 		}
 	}

-		n.wakeupUsers(usersToNotify, peekingDevicesToNotify, n.currPos)
+		n._wakeupUsers(usersToNotify, peekingDevicesToNotify, n.currPos)
 	} else if roomID != "" {
-		n.wakeupUsers(n.JoinedUsers(roomID), n.PeekingDevices(roomID), n.currPos)
+		n._wakeupUsers(n._joinedUsers(roomID), n._peekingDevices(roomID), n.currPos)
 	} else if len(userIDs) > 0 {
-		n.wakeupUsers(userIDs, nil, n.currPos)
+		n._wakeupUsers(userIDs, nil, n.currPos)
 	} else {
 		log.WithFields(log.Fields{
 			"posUpdate": posUpdate.String,

@@ -141,22 +139,22 @@ func (n *Notifier) OnNewEvent(
 func (n *Notifier) OnNewAccountData(
 	userID string, posUpdate types.StreamingToken,
 ) {
-	n.streamLock.Lock()
-	defer n.streamLock.Unlock()
+	n.lock.Lock()
+	defer n.lock.Unlock()

 	n.currPos.ApplyUpdates(posUpdate)
-	n.wakeupUsers([]string{userID}, nil, posUpdate)
+	n._wakeupUsers([]string{userID}, nil, posUpdate)
 }

 func (n *Notifier) OnNewPeek(
 	roomID, userID, deviceID string,
 	posUpdate types.StreamingToken,
 ) {
-	n.streamLock.Lock()
-	defer n.streamLock.Unlock()
+	n.lock.Lock()
+	defer n.lock.Unlock()

 	n.currPos.ApplyUpdates(posUpdate)
-	n.addPeekingDevice(roomID, userID, deviceID)
+	n._addPeekingDevice(roomID, userID, deviceID)

 	// we don't wake up devices here given the roomserver consumer will do this shortly afterwards
 	// by calling OnNewEvent.

@@ -166,11 +164,11 @@ func (n *Notifier) OnRetirePeek(
 	roomID, userID, deviceID string,
 	posUpdate types.StreamingToken,
 ) {
-	n.streamLock.Lock()
-	defer n.streamLock.Unlock()
+	n.lock.Lock()
+	defer n.lock.Unlock()

 	n.currPos.ApplyUpdates(posUpdate)
-	n.removePeekingDevice(roomID, userID, deviceID)
+	n._removePeekingDevice(roomID, userID, deviceID)

 	// we don't wake up devices here given the roomserver consumer will do this shortly afterwards
 	// by calling OnRetireEvent.

@@ -180,11 +178,11 @@ func (n *Notifier) OnNewSendToDevice(
 	userID string, deviceIDs []string,
 	posUpdate types.StreamingToken,
 ) {
-	n.streamLock.Lock()
-	defer n.streamLock.Unlock()
+	n.lock.Lock()
+	defer n.lock.Unlock()

 	n.currPos.ApplyUpdates(posUpdate)
-	n.wakeupUserDevice(userID, deviceIDs, n.currPos)
+	n._wakeupUserDevice(userID, deviceIDs, n.currPos)
 }

 // OnNewReceipt updates the current position

@@ -192,11 +190,11 @@ func (n *Notifier) OnNewTyping(
 	roomID string,
 	posUpdate types.StreamingToken,
 ) {
-	n.streamLock.Lock()
-	defer n.streamLock.Unlock()
+	n.lock.Lock()
+	defer n.lock.Unlock()

 	n.currPos.ApplyUpdates(posUpdate)
-	n.wakeupUsers(n.JoinedUsers(roomID), nil, n.currPos)
+	n._wakeupUsers(n._joinedUsers(roomID), nil, n.currPos)
 }

 // OnNewReceipt updates the current position

@@ -204,80 +202,84 @@ func (n *Notifier) OnNewReceipt(
 	roomID string,
 	posUpdate types.StreamingToken,
 ) {
-	n.streamLock.Lock()
-	defer n.streamLock.Unlock()
+	n.lock.Lock()
+	defer n.lock.Unlock()

 	n.currPos.ApplyUpdates(posUpdate)
-	n.wakeupUsers(n.JoinedUsers(roomID), nil, n.currPos)
+	n._wakeupUsers(n._joinedUsers(roomID), nil, n.currPos)
 }

 func (n *Notifier) OnNewKeyChange(
 	posUpdate types.StreamingToken, wakeUserID, keyChangeUserID string,
 ) {
-	n.streamLock.Lock()
-	defer n.streamLock.Unlock()
+	n.lock.Lock()
+	defer n.lock.Unlock()

 	n.currPos.ApplyUpdates(posUpdate)
-	n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
+	n._wakeupUsers([]string{wakeUserID}, nil, n.currPos)
 }

 func (n *Notifier) OnNewInvite(
 	posUpdate types.StreamingToken, wakeUserID string,
 ) {
-	n.streamLock.Lock()
-	defer n.streamLock.Unlock()
+	n.lock.Lock()
+	defer n.lock.Unlock()

 	n.currPos.ApplyUpdates(posUpdate)
-	n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
+	n._wakeupUsers([]string{wakeUserID}, nil, n.currPos)
 }

 func (n *Notifier) OnNewNotificationData(
 	userID string,
 	posUpdate types.StreamingToken,
 ) {
-	n.streamLock.Lock()
-	defer n.streamLock.Unlock()
+	n.lock.Lock()
+	defer n.lock.Unlock()

 	n.currPos.ApplyUpdates(posUpdate)
-	n.wakeupUsers([]string{userID}, nil, n.currPos)
+	n._wakeupUsers([]string{userID}, nil, n.currPos)
 }

 func (n *Notifier) OnNewPresence(
 	posUpdate types.StreamingToken, userID string,
 ) {
-	n.streamLock.Lock()
-	defer n.streamLock.Unlock()
+	n.lock.Lock()
+	defer n.lock.Unlock()

 	n.currPos.ApplyUpdates(posUpdate)
-	sharedUsers := n.SharedUsers(userID)
+	sharedUsers := n._sharedUsers(userID)
 	sharedUsers = append(sharedUsers, userID)

-	n.wakeupUsers(sharedUsers, nil, n.currPos)
+	n._wakeupUsers(sharedUsers, nil, n.currPos)
 }

 func (n *Notifier) SharedUsers(userID string) []string {
-	n.mapLock.RLock()
-	defer n.mapLock.RUnlock()
-	n._sharedUsers[userID] = struct{}{}
+	n.lock.RLock()
+	defer n.lock.RUnlock()
+	return n._sharedUsers(userID)
+}
+
+func (n *Notifier) _sharedUsers(userID string) []string {
+	n._sharedUserMap[userID] = struct{}{}
 	for roomID, users := range n.roomIDToJoinedUsers {
 		if _, ok := users[userID]; !ok {
 			continue
 		}
-		for _, userID := range n.JoinedUsers(roomID) {
-			n._sharedUsers[userID] = struct{}{}
+		for _, userID := range n._joinedUsers(roomID) {
+			n._sharedUserMap[userID] = struct{}{}
 		}
 	}
-	sharedUsers := make([]string, 0, len(n._sharedUsers)+1)
-	for userID := range n._sharedUsers {
+	sharedUsers := make([]string, 0, len(n._sharedUserMap)+1)
+	for userID := range n._sharedUserMap {
 		sharedUsers = append(sharedUsers, userID)
-		delete(n._sharedUsers, userID)
+		delete(n._sharedUserMap, userID)
 	}
 	return sharedUsers
 }

 func (n *Notifier) IsSharedUser(userA, userB string) bool {
-	n.mapLock.RLock()
-	defer n.mapLock.RUnlock()
+	n.lock.RLock()
+	defer n.lock.RUnlock()
 	var okA, okB bool
 	for _, users := range n.roomIDToJoinedUsers {
 		_, okA = users[userA]
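
_sharedUsers drains a scratch map (_sharedUserMap) that lives on the struct, so each call reuses one allocation instead of building a fresh map — safe only because every caller holds the notifier lock. The same trick in isolation, with illustrative names:

package example

type dedup struct {
	scratch map[string]struct{} // reused across calls; always empty on entry
}

func newDedup() *dedup {
	return &dedup{scratch: make(map[string]struct{})}
}

// unique returns the distinct values of in, draining d.scratch as it goes
// so the map is empty again for the next caller.
func (d *dedup) unique(in []string) []string {
	for _, v := range in {
		d.scratch[v] = struct{}{}
	}
	out := make([]string, 0, len(d.scratch))
	for v := range d.scratch {
		out = append(out, v)
		delete(d.scratch, v) // deleting during range is safe in Go
	}
	return out
}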

@@ -301,18 +303,18 @@ func (n *Notifier) GetListener(req types.SyncRequest) UserDeviceStreamListener {
 	// TODO: v1 /events 'peeking' has an 'explicit room ID' which is also tracked,
 	// but given we don't do /events, let's pretend it doesn't exist.

-	n.streamLock.Lock()
-	defer n.streamLock.Unlock()
+	n.lock.Lock()
+	defer n.lock.Unlock()

-	n.removeEmptyUserStreams()
+	n._removeEmptyUserStreams()

-	return n.fetchUserDeviceStream(req.Device.UserID, req.Device.ID, true).GetListener(req.Context)
+	return n._fetchUserDeviceStream(req.Device.UserID, req.Device.ID, true).GetListener(req.Context)
 }

 // Load the membership states required to notify users correctly.
 func (n *Notifier) Load(ctx context.Context, db storage.Database) error {
-	n.mapLock.Lock()
-	defer n.mapLock.Unlock()
+	n.lock.Lock()
+	defer n.lock.Unlock()
 	roomToUsers, err := db.AllJoinedUsersInRooms(ctx)
 	if err != nil {
 		return err

@@ -330,8 +332,8 @@ func (n *Notifier) Load(ctx context.Context, db storage.Database) error {

 // CurrentPosition returns the current sync position
 func (n *Notifier) CurrentPosition() types.StreamingToken {
-	n.streamLock.Lock()
-	defer n.streamLock.Unlock()
+	n.lock.RLock()
+	defer n.lock.RUnlock()

 	return n.currPos
 }

@@ -366,11 +368,11 @@ func (n *Notifier) setPeekingDevices(roomIDToPeekingDevices map[string][]types.P
 	}
 }

-// wakeupUsers will wake up the sync strems for all of the
+// _wakeupUsers will wake up the sync strems for all of the
 // specified user IDs, and also the specified peekingDevices
-func (n *Notifier) wakeupUsers(userIDs []string, peekingDevices []types.PeekingDevice, newPos types.StreamingToken) {
+func (n *Notifier) _wakeupUsers(userIDs []string, peekingDevices []types.PeekingDevice, newPos types.StreamingToken) {
 	for _, userID := range userIDs {
-		for _, stream := range n.fetchUserStreams(userID) {
+		for _, stream := range n._fetchUserStreams(userID) {
 			if stream == nil {
 				continue
 			}

@@ -380,28 +382,27 @@ func (n *Notifier) _wakeupUsers(userIDs []string, peekingDevices []types.PeekingD

 	for _, peekingDevice := range peekingDevices {
 		// TODO: don't bother waking up for devices whose users we already woke up
-		if stream := n.fetchUserDeviceStream(peekingDevice.UserID, peekingDevice.DeviceID, false); stream != nil {
+		if stream := n._fetchUserDeviceStream(peekingDevice.UserID, peekingDevice.DeviceID, false); stream != nil {
 			stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream
 		}
 	}
 }

-// wakeupUserDevice will wake up the sync stream for a specific user device. Other
+// _wakeupUserDevice will wake up the sync stream for a specific user device. Other
 // device streams will be left alone.
-// nolint:unused
-func (n *Notifier) wakeupUserDevice(userID string, deviceIDs []string, newPos types.StreamingToken) {
+func (n *Notifier) _wakeupUserDevice(userID string, deviceIDs []string, newPos types.StreamingToken) {
 	for _, deviceID := range deviceIDs {
-		if stream := n.fetchUserDeviceStream(userID, deviceID, false); stream != nil {
+		if stream := n._fetchUserDeviceStream(userID, deviceID, false); stream != nil {
 			stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream
 		}
 	}
 }

-// fetchUserDeviceStream retrieves a stream unique to the given device. If makeIfNotExists is true,
+// _fetchUserDeviceStream retrieves a stream unique to the given device. If makeIfNotExists is true,
 // a stream will be made for this device if one doesn't exist and it will be returned. This
 // function does not wait for data to be available on the stream.
 // NB: Callers should have locked the mutex before calling this function.
-func (n *Notifier) fetchUserDeviceStream(userID, deviceID string, makeIfNotExists bool) *UserDeviceStream {
+func (n *Notifier) _fetchUserDeviceStream(userID, deviceID string, makeIfNotExists bool) *UserDeviceStream {
 	_, ok := n.userDeviceStreams[userID]
 	if !ok {
 		if !makeIfNotExists {

@@ -422,11 +423,10 @@ func (n *Notifier) _fetchUserDeviceStream(userID, deviceID string, makeIfNotExist
 	return stream
 }

-// fetchUserStreams retrieves all streams for the given user. If makeIfNotExists is true,
+// _fetchUserStreams retrieves all streams for the given user. If makeIfNotExists is true,
 // a stream will be made for this user if one doesn't exist and it will be returned. This
 // function does not wait for data to be available on the stream.
-// NB: Callers should have locked the mutex before calling this function.
-func (n *Notifier) fetchUserStreams(userID string) []*UserDeviceStream {
+func (n *Notifier) _fetchUserStreams(userID string) []*UserDeviceStream {
 	user, ok := n.userDeviceStreams[userID]
 	if !ok {
 		return []*UserDeviceStream{}
|
||||
|
|
@ -438,51 +438,41 @@ func (n *Notifier) fetchUserStreams(userID string) []*UserDeviceStream {
|
|||
return streams
|
||||
}
|
||||
|
||||
// Not thread-safe: must be called on the OnNewEvent goroutine only
|
||||
func (n *Notifier) addJoinedUser(roomID, userID string) {
|
||||
n.mapLock.Lock()
|
||||
defer n.mapLock.Unlock()
|
||||
func (n *Notifier) _addJoinedUser(roomID, userID string) {
|
||||
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
|
||||
n.roomIDToJoinedUsers[roomID] = make(userIDSet)
|
||||
}
|
||||
n.roomIDToJoinedUsers[roomID].add(userID)
|
||||
}
|
||||
|
||||
// Not thread-safe: must be called on the OnNewEvent goroutine only
|
||||
func (n *Notifier) removeJoinedUser(roomID, userID string) {
|
||||
n.mapLock.Lock()
|
||||
defer n.mapLock.Unlock()
|
||||
func (n *Notifier) _removeJoinedUser(roomID, userID string) {
|
||||
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
|
||||
n.roomIDToJoinedUsers[roomID] = make(userIDSet)
|
||||
}
|
||||
n.roomIDToJoinedUsers[roomID].remove(userID)
|
||||
}
|
||||
|
||||
// Not thread-safe: must be called on the OnNewEvent goroutine only
|
||||
func (n *Notifier) JoinedUsers(roomID string) (userIDs []string) {
|
||||
n.mapLock.RLock()
|
||||
defer n.mapLock.RUnlock()
|
||||
n.lock.RLock()
|
||||
defer n.lock.RUnlock()
|
||||
return n._joinedUsers(roomID)
|
||||
}
|
||||
|
||||
func (n *Notifier) _joinedUsers(roomID string) (userIDs []string) {
|
||||
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
|
||||
return
|
||||
}
|
||||
return n.roomIDToJoinedUsers[roomID].values()
|
||||
}
|
||||
|
||||
// Not thread-safe: must be called on the OnNewEvent goroutine only
|
||||
func (n *Notifier) addPeekingDevice(roomID, userID, deviceID string) {
|
||||
n.mapLock.Lock()
|
||||
defer n.mapLock.Unlock()
|
||||
func (n *Notifier) _addPeekingDevice(roomID, userID, deviceID string) {
|
||||
if _, ok := n.roomIDToPeekingDevices[roomID]; !ok {
|
||||
n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet)
|
||||
}
|
||||
n.roomIDToPeekingDevices[roomID].add(types.PeekingDevice{UserID: userID, DeviceID: deviceID})
|
||||
}
|
||||
|
||||
// Not thread-safe: must be called on the OnNewEvent goroutine only
|
||||
// nolint:unused
|
||||
func (n *Notifier) removePeekingDevice(roomID, userID, deviceID string) {
|
||||
n.mapLock.Lock()
|
||||
defer n.mapLock.Unlock()
|
||||
func (n *Notifier) _removePeekingDevice(roomID, userID, deviceID string) {
|
||||
if _, ok := n.roomIDToPeekingDevices[roomID]; !ok {
|
||||
n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet)
|
||||
}
|
||||
|
|
@ -490,24 +480,26 @@ func (n *Notifier) removePeekingDevice(roomID, userID, deviceID string) {
|
|||
n.roomIDToPeekingDevices[roomID].remove(types.PeekingDevice{UserID: userID, DeviceID: deviceID})
|
||||
}
|
||||
|
||||
// Not thread-safe: must be called on the OnNewEvent goroutine only
|
||||
func (n *Notifier) PeekingDevices(roomID string) (peekingDevices []types.PeekingDevice) {
|
||||
n.mapLock.RLock()
|
||||
defer n.mapLock.RUnlock()
|
||||
n.lock.RLock()
|
||||
defer n.lock.RUnlock()
|
||||
return n._peekingDevices(roomID)
|
||||
}
|
||||
|
||||
func (n *Notifier) _peekingDevices(roomID string) (peekingDevices []types.PeekingDevice) {
|
||||
if _, ok := n.roomIDToPeekingDevices[roomID]; !ok {
|
||||
return
|
||||
}
|
||||
return n.roomIDToPeekingDevices[roomID].values()
|
||||
}
|
||||
|
||||
// removeEmptyUserStreams iterates through the user stream map and removes any
|
||||
// _removeEmptyUserStreams iterates through the user stream map and removes any
|
||||
// that have been empty for a certain amount of time. This is a crude way of
|
||||
// ensuring that the userStreams map doesn't grow forver.
|
||||
// This should be called when the notifier gets called for whatever reason,
|
||||
// the function itself is responsible for ensuring it doesn't iterate too
|
||||
// often.
|
||||
// NB: Callers should have locked the mutex before calling this function.
|
||||
func (n *Notifier) removeEmptyUserStreams() {
|
||||
func (n *Notifier) _removeEmptyUserStreams() {
|
||||
// Only clean up now and again
|
||||
now := time.Now()
|
||||
if n.lastCleanUpTime.Add(time.Minute).After(now) {
|
||||
|
|
|
|||
|
|
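The renames above follow a common Go locking convention: the exported methods are the only places the notifier's mutex is acquired, while the underscore-prefixed variants document that the caller must already hold the lock. This lets internal helpers compose without re-acquiring (and deadlocking on) the same mutex. A minimal sketch of the pattern, with illustrative names rather than Dendrite's actual types:

    package example

    import "sync"

    type notifier struct {
        lock  sync.RWMutex
        rooms map[string][]string // roomID -> joined user IDs
    }

    // JoinedUsers is the public entry point: it takes the read lock once.
    func (n *notifier) JoinedUsers(roomID string) []string {
        n.lock.RLock()
        defer n.lock.RUnlock()
        return n._joinedUsers(roomID)
    }

    // _joinedUsers assumes n.lock is already held by the caller, so other
    // locked methods can call it without taking the lock a second time.
    func (n *notifier) _joinedUsers(roomID string) []string {
        return n.rooms[roomID]
    }

The diff also collapses the separate streamLock and mapLock into a single lock, which is the situation the convention protects: with one mutex, any internal helper that locked again would deadlock.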
@@ -165,9 +165,9 @@ func TestCorrectStreamWakeup(t *testing.T) {

go func() {
select {
case <-streamone.signalChannel:
case <-streamone.ch():
awoken <- "one"
case <-streamtwo.signalChannel:
case <-streamtwo.ch():
awoken <- "two"
}
}()

@@ -175,7 +175,7 @@ func TestCorrectStreamWakeup(t *testing.T) {
time.Sleep(1 * time.Second)

wake := "two"
n.wakeupUserDevice(alice, []string{wake}, syncPositionAfter)
n._wakeupUserDevice(alice, []string{wake}, syncPositionAfter)

if result := <-awoken; result != wake {
t.Fatalf("expected to wake %q, got %q", wake, result)

@@ -359,10 +359,10 @@ func waitForBlocking(s *UserDeviceStream, numBlocking uint) {
// lockedFetchUserStream invokes Notifier.fetchUserStream, respecting Notifier.streamLock.
// A new stream is made if it doesn't exist already.
func lockedFetchUserStream(n *Notifier, userID, deviceID string) *UserDeviceStream {
n.streamLock.Lock()
defer n.streamLock.Unlock()
n.lock.Lock()
defer n.lock.Unlock()

return n.fetchUserDeviceStream(userID, deviceID, true)
return n._fetchUserDeviceStream(userID, deviceID, true)
}

func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) types.SyncRequest {
@@ -118,6 +118,12 @@ func (s *UserDeviceStream) TimeOfLastNonEmpty() time.Time {
return s.timeOfLastChannel
}

func (s *UserDeviceStream) ch() <-chan struct{} {
s.lock.Lock()
defer s.lock.Unlock()
return s.signalChannel
}

// GetSyncPosition returns last sync position which the UserStream was
// notified about
func (s *UserDeviceStreamListener) GetSyncPosition() types.StreamingToken {
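The new ch() accessor exists so that callers, such as the updated test above, can read the signal channel without racing against concurrent writers. A plausible reading of the surrounding code, sketched with illustrative names: Broadcast appears to wake waiters by closing the current channel and installing a fresh one, so the field itself is mutated and must only be read under the lock.

    package example

    import "sync"

    type stream struct {
        lock          sync.Mutex
        signalChannel chan struct{}
    }

    // Broadcast wakes every goroutine selecting on ch() by closing the old
    // channel, then replaces it for future waiters. (Assumed mechanics,
    // inferred from the Broadcast calls in the notifier diff above.)
    func (s *stream) Broadcast() {
        s.lock.Lock()
        defer s.lock.Unlock()
        close(s.signalChannel)
        s.signalChannel = make(chan struct{})
    }

    // ch snapshots the current channel under the lock, making it safe to
    // select on from another goroutine.
    func (s *stream) ch() <-chan struct{} {
        s.lock.Lock()
        defer s.lock.Unlock()
        return s.signalChannel
    }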
@@ -60,7 +60,9 @@ func Context(
Headers: nil,
}
}
filter.Rooms = append(filter.Rooms, roomID)
if filter.Rooms != nil {
*filter.Rooms = append(*filter.Rooms, roomID)
}

ctx := req.Context()
membershipRes := roomserver.QueryMembershipForUserResponse{}
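The nil check is needed because the filter's Rooms field is now a pointer to a slice (in line with the pointer-based filter changes elsewhere in this commit): nil means the field was absent and every room matches, while a concrete slice means only the listed rooms match, so the room ID is only appended when an explicit list exists. A small sketch of the distinction, with a hypothetical filter type:

    // With a plain []string, nil and empty both look like "no filter".
    // With a *[]string, the two cases can be told apart:
    type roomFilter struct {
        Rooms *[]string
    }

    func appendRoom(f *roomFilter, roomID string) {
        if f.Rooms != nil {
            *f.Rooms = append(*f.Rooms, roomID) // narrow an explicit list
        }
        // f.Rooms == nil: the filter already matches all rooms, so there
        // is nothing to narrow.
    }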
@@ -104,7 +104,7 @@ type Database interface {
// DeletePeek deletes all peeks for a given room by a given user
// Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardOrdering is true, the most recent event must be first, else last.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event.
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)

@@ -150,6 +150,9 @@ type Database interface {
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)

StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error)

IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error)
UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error
}

type Presence interface {
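A typical consumer of the two new ignore-list methods would read, modify, and write back the stored blob. A hedged usage fragment (error handling trimmed; the exact shape of types.IgnoredUsers is assumed, and a missing row is assumed to surface as an error from the underlying query):

    ignores, err := db.IgnoresForUser(ctx, "@alice:example.com")
    if err != nil {
        ignores = &types.IgnoredUsers{} // no stored list yet: start empty
    }
    // ...apply the user's new m.ignored_user_list account data...
    if err := db.UpdateIgnoresForUser(ctx, "@alice:example.com", ignores); err != nil {
        return err
    }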
@@ -47,14 +47,10 @@ const selectBackwardExtremitiesForRoomSQL = "" +
const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"

const deleteBackwardExtremitiesForRoomSQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"

type backwardExtremitiesStatements struct {
insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
}

func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {

@@ -72,9 +68,6 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
return nil, err
}
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
return nil, err
}
return s, nil
}

@@ -113,10 +106,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return
}

func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
return err
}
@@ -233,9 +233,10 @@ func (s *currentRoomStateStatements) SelectCurrentState(
excludeEventIDs []string,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err := stmt.QueryContext(ctx, roomID,
pq.StringArray(stateFilter.Senders),
pq.StringArray(stateFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
stateFilter.ContainsURL,
@@ -16,21 +16,45 @@ package postgres

import (
"strings"

"github.com/matrix-org/gomatrixserverlib"
)

// filterConvertWildcardToSQL converts wildcards as defined in
// https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter
// to SQL wildcards that can be used with LIKE()
func filterConvertTypeWildcardToSQL(values []string) []string {
func filterConvertTypeWildcardToSQL(values *[]string) []string {
if values == nil {
// Return nil instead of []string{} so IS NULL can work correctly when
// the return value is passed into SQL queries
return nil
}

ret := make([]string, len(values))
for i := range values {
ret[i] = strings.Replace(values[i], "*", "%", -1)
v := *values
ret := make([]string, len(v))
for i := range v {
ret[i] = strings.Replace(v[i], "*", "%", -1)
}
return ret
}

// TODO: Replace when Dendrite uses Go 1.18
func getSendersRoomEventFilter(filter *gomatrixserverlib.RoomEventFilter) (senders []string, notSenders []string) {
if filter.Senders != nil {
senders = *filter.Senders
}
if filter.NotSenders != nil {
notSenders = *filter.NotSenders
}
return senders, notSenders
}

func getSendersStateFilterFilter(filter *gomatrixserverlib.StateFilter) (senders []string, notSenders []string) {
if filter.Senders != nil {
senders = *filter.Senders
}
if filter.NotSenders != nil {
notSenders = *filter.NotSenders
}
return senders, notSenders
}
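The pointer-based signature lets the filter distinguish "no type filter" (nil, which becomes SQL NULL and can be handled with IS NULL in the query) from an explicit, possibly empty list. The wildcard rewrite itself is a straight * to % substitution for LIKE-style matching. Illustrative behaviour:

    types := []string{"m.room.*", "m.call.invite"}
    converted := filterConvertTypeWildcardToSQL(&types)
    // converted == []string{"m.room.%", "m.call.invite"}
    // filterConvertTypeWildcardToSQL(nil) == nil, so the bound SQL
    // parameter is NULL rather than an empty array.

The two getSenders helpers are identical except for their parameter type, which is what the Go 1.18 TODO is about: with generics they could plausibly collapse into a single dereferencing helper such as

    func derefOrNil[T any](p *[]T) []T {
        if p == nil {
            return nil
        }
        return *p
    }

applied to Senders and NotSenders in turn.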
syncapi/storage/postgres/ignores_table.go (new file, 87 lines)

@@ -0,0 +1,87 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package postgres

import (
"context"
"database/sql"
"encoding/json"

"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)

const ignoresSchema = `
-- Stores data about ignores
CREATE TABLE IF NOT EXISTS syncapi_ignores (
-- The user ID whose ignore list this belongs to.
user_id TEXT NOT NULL,
ignores_json TEXT NOT NULL,
PRIMARY KEY(user_id)
);
`

const selectIgnoresSQL = "" +
"SELECT ignores_json FROM syncapi_ignores WHERE user_id = $1"

const upsertIgnoresSQL = "" +
"INSERT INTO syncapi_ignores (user_id, ignores_json) VALUES ($1, $2)" +
" ON CONFLICT (user_id) DO UPDATE set ignores_json = $2"

type ignoresStatements struct {
selectIgnoresStmt *sql.Stmt
upsertIgnoresStmt *sql.Stmt
}

func NewPostgresIgnoresTable(db *sql.DB) (tables.Ignores, error) {
_, err := db.Exec(ignoresSchema)
if err != nil {
return nil, err
}
s := &ignoresStatements{}
if s.selectIgnoresStmt, err = db.Prepare(selectIgnoresSQL); err != nil {
return nil, err
}
if s.upsertIgnoresStmt, err = db.Prepare(upsertIgnoresSQL); err != nil {
return nil, err
}
return s, nil
}

func (s *ignoresStatements) SelectIgnores(
ctx context.Context, userID string,
) (*types.IgnoredUsers, error) {
var ignoresData []byte
err := s.selectIgnoresStmt.QueryRowContext(ctx, userID).Scan(&ignoresData)
if err != nil {
return nil, err
}
var ignores types.IgnoredUsers
if err = json.Unmarshal(ignoresData, &ignores); err != nil {
return nil, err
}
return &ignores, nil
}

func (s *ignoresStatements) UpsertIgnores(
ctx context.Context, userID string, ignores *types.IgnoredUsers,
) error {
ignoresJSON, err := json.Marshal(ignores)
if err != nil {
return err
}
_, err = s.upsertIgnoresStmt.ExecContext(ctx, userID, ignoresJSON)
return err
}
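The table stores one JSON blob per user, keyed on user_id, so UpsertIgnores overwrites the whole ignore list each time via ON CONFLICT on the primary key. A hedged usage fragment (the contents of types.IgnoredUsers are assumed; in practice these calls are reached through the Database wrappers shown later in this diff):

    table, err := NewPostgresIgnoresTable(db) // db: an open *sql.DB
    if err != nil {
        return err
    }
    if err := table.UpsertIgnores(ctx, "@alice:example.com", ignores); err != nil {
        return err
    }
    got, err := table.SelectIgnores(ctx, "@alice:example.com")
    // got round-trips through the ignores_json column; a user with no
    // stored list yields sql.ErrNoRows from the underlying QueryRowContext.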
@@ -56,12 +56,6 @@ const upsertMembershipSQL = "" +
" ON CONFLICT ON CONSTRAINT syncapi_memberships_unique" +
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"

const selectMembershipSQL = "" +
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
" WHERE room_id = $1 AND user_id = $2 AND membership = ANY($3)" +
" ORDER BY stream_pos DESC" +
" LIMIT 1"

const selectMembershipCountSQL = "" +
"SELECT COUNT(*) FROM (" +
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +

@@ -69,7 +63,6 @@ const selectMembershipCountSQL = "" +

type membershipsStatements struct {
upsertMembershipStmt *sql.Stmt
selectMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
}

@@ -82,9 +75,6 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil {
return nil, err
}
if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil {
return nil, err
}
if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil {
return nil, err
}

@@ -111,14 +101,6 @@ func (s *membershipsStatements) UpsertMembership(
return err
}

func (s *membershipsStatements) SelectMembership(
ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string,
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt)
err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos)
return
}

func (s *membershipsStatements) SelectMembershipCount(
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
) (count int, err error) {
@@ -204,11 +204,11 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
stateFilter *gomatrixserverlib.StateFilter, roomIDs []string,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)

senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err := stmt.QueryContext(
ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
pq.StringArray(stateFilter.Senders),
pq.StringArray(stateFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
stateFilter.ContainsURL,

@@ -353,10 +353,11 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
} else {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
}
senders, notSenders := getSendersRoomEventFilter(eventFilter)
rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders),
pq.StringArray(eventFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit+1,

@@ -398,11 +399,12 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
) ([]types.StreamEvent, error) {
senders, notSenders := getSendersRoomEventFilter(eventFilter)
stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt)
rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders),
pq.StringArray(eventFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit,

@@ -427,7 +429,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
// selectEvents returns the events for the given event IDs. If an event is
// missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) SelectEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string,
ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool,
) ([]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))

@@ -435,7 +437,25 @@ func (s *outputRoomEventsStatements) SelectEvents(
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
return rowsToStreamEvents(rows)
streamEvents, err := rowsToStreamEvents(rows)
if err != nil {
return nil, err
}
if preserveOrder {
eventMap := make(map[string]types.StreamEvent)
for _, ev := range streamEvents {
eventMap[ev.EventID()] = ev
}
var returnEvents []types.StreamEvent
for _, eventID := range eventIDs {
ev, ok := eventMap[eventID]
if ok {
returnEvents = append(returnEvents, ev)
}
}
return returnEvents, nil
}
return streamEvents, nil
}

func (s *outputRoomEventsStatements) DeleteEventsForRoom(

@@ -462,10 +482,11 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn
func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
) (evts []*gomatrixserverlib.HeaderedEvent, err error) {
senders, notSenders := getSendersRoomEventFilter(filter)
rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext(
ctx, roomID, id, filter.Limit,
pq.StringArray(filter.Senders),
pq.StringArray(filter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
)

@@ -494,10 +515,11 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
func (s *outputRoomEventsStatements) SelectContextAfterEvent(
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) {
senders, notSenders := getSendersRoomEventFilter(filter)
rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext(
ctx, roomID, id, filter.Limit,
pq.StringArray(filter.Senders),
pq.StringArray(filter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
)
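The motivation for preserveOrder is that WHERE ... = ANY(...) style queries return rows in whatever order the database produces them, not in argument order. GetEventsInTopologicalRange (later in this diff) passes true so the events come back in the requested topological order. The reorder step is the standard map-then-walk technique:

    // General form of the reorder: index results by ID, then walk the
    // requested IDs so output order matches input order; IDs with no row
    // are simply skipped. (Event here is a stand-in for types.StreamEvent.)
    byID := make(map[string]Event, len(results))
    for _, ev := range results {
        byID[ev.ID] = ev
    }
    ordered := make([]Event, 0, len(ids))
    for _, id := range ids {
        if ev, ok := byID[id]; ok {
            ordered = append(ordered, ev)
        }
    }

This costs O(n) extra memory in exchange for not asking the database to sort by an arbitrary caller-supplied ordering.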
@@ -73,9 +73,6 @@ const selectMaxPositionInTopologySQL = "" +
"SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" +
") ORDER BY stream_position DESC LIMIT 1"

const deleteTopologyForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"

const selectStreamToTopologicalPositionAscSQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;"

@@ -88,7 +85,6 @@ type outputRoomEventsTopologyStatements struct {
selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt
deleteTopologyForRoomStmt *sql.Stmt
selectStreamToTopologicalPositionAscStmt *sql.Stmt
selectStreamToTopologicalPositionDescStmt *sql.Stmt
}

@@ -114,9 +110,6 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) {
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
return nil, err
}
if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil {
return nil, err
}
if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil {
return nil, err
}

@@ -148,9 +141,9 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
// is requested or not.
var stmt *sql.Stmt
if chronologicalOrder {
stmt = s.selectEventIDsInRangeASCStmt
stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt)
} else {
stmt = s.selectEventIDsInRangeDESCStmt
stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt)
}

// Query the event IDs.

@@ -203,10 +196,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return
}

func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
return err
}
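Routing the range statements through sqlutil.TxStmt fixes statement use inside transactions: a prepared statement created on the *sql.DB must be rebound to a transaction before it can participate in it. Dendrite's helper behaves roughly like this (a sketch of internal/sqlutil, paraphrased from its usage in this diff rather than quoted):

    // TxStmt returns a transaction-scoped copy of stmt when txn is non-nil,
    // and the original statement otherwise, so callers can be written once
    // for both transactional and non-transactional paths.
    func TxStmt(txn *sql.Tx, stmt *sql.Stmt) *sql.Stmt {
        if txn != nil {
            stmt = txn.Stmt(stmt)
        }
        return stmt
    }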
@@ -90,6 +90,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
if err != nil {
return nil, err
}
ignores, err := NewPostgresIgnoresTable(d.db)
if err != nil {
return nil, err
}
presence, err := NewPostgresPresenceTable(d.db)
if err != nil {
return nil, err

@@ -115,6 +119,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
Receipts: receipts,
Memberships: memberships,
NotificationData: notificationData,
Ignores: ignores,
Presence: presence,
}
return &d, nil
@@ -48,6 +48,7 @@ type Database struct {
Receipts tables.Receipts
Memberships tables.Memberships
NotificationData tables.NotificationData
Ignores tables.Ignores
Presence tables.Presence
}

@@ -149,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre
// Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events.
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs)
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, false)
if err != nil {
return nil, err
}

@@ -311,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e

// Check if we have all of the event's previous events. If an event is
// missing, add it to the room's backward extremities.
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs())
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), false)
if err != nil {
return err
}

@@ -456,7 +457,7 @@ func (d *Database) GetEventsInTopologicalRange(
}

// Retrieve the events' contents using their IDs.
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs)
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, true)
return
}

@@ -618,7 +619,7 @@ func (d *Database) fetchMissingStateEvents(
) ([]types.StreamEvent, error) {
// Fetch from the events table first so we pick up the stream ID for the
// event.
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs)
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, false)
if err != nil {
return nil, err
}

@@ -1004,6 +1005,14 @@ func (s *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID s
return s.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter)
}

func (s *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) {
return s.Ignores.SelectIgnores(ctx, userID)
}

func (s *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error {
return s.Ignores.UpsertIgnores(ctx, userID, ignores)
}

func (s *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) {
return s.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync)
}
@@ -41,23 +41,23 @@ const insertAccountDataSQL = "" +
" ON CONFLICT (user_id, room_id, type) DO UPDATE" +
" SET id = $5"

// further parameters are added by prepareWithFilters
const selectAccountDataInRangeSQL = "" +
"SELECT room_id, type FROM syncapi_account_data_type" +
" WHERE user_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC"
" WHERE user_id = $1 AND id > $2 AND id <= $3"

const selectMaxAccountDataIDSQL = "" +
"SELECT MAX(id) FROM syncapi_account_data_type"

type accountDataStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
insertAccountDataStmt *sql.Stmt
selectMaxAccountDataIDStmt *sql.Stmt
selectAccountDataInRangeStmt *sql.Stmt
}

func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
func NewSqliteAccountDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{
db: db,
streamIDStatements: streamID,

@@ -94,18 +94,25 @@ func (s *accountDataStatements) SelectAccountDataInRange(
ctx context.Context,
userID string,
r types.Range,
accountDataFilterPart *gomatrixserverlib.EventFilter,
filter *gomatrixserverlib.EventFilter,
) (data map[string][]string, err error) {
data = make(map[string][]string)
stmt, params, err := prepareWithFilters(
s.db, nil, selectAccountDataInRangeSQL,
[]interface{}{
userID, r.Low(), r.High(),
},
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
[]string{}, filter.Limit, FilterOrderAsc,
)

rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High())
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")

var entries int

for rows.Next() {
var dataType string
var roomID string

@@ -114,31 +121,11 @@ func (s *accountDataStatements) SelectAccountDataInRange(
return
}

// check if we should add this by looking at the filter.
// It would be nice if we could do this in SQL-land, but the mix of variadic
// and positional parameters makes the query annoyingly hard to do, it's easier
// and clearer to do it in Go-land. If there are no filters for [not]types then
// this gets skipped.
for _, includeType := range accountDataFilterPart.Types {
if includeType != dataType { // TODO: wildcard support
continue
}
}
for _, excludeType := range accountDataFilterPart.NotTypes {
if excludeType == dataType { // TODO: wildcard support
continue
}
}

if len(data[roomID]) > 0 {
data[roomID] = append(data[roomID], dataType)
} else {
data[roomID] = []string{dataType}
}
entries++
if entries >= accountDataFilterPart.Limit {
break
}
}

return data, nil
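The removed comment conceded that filtering account data "in Go-land" was easier than mixing variadic and positional SQL parameters; prepareWithFilters (shown below in the SQLite filtering changes) now solves exactly that by appending the clauses and renumbering parameters itself. Notably, the removed type checks appear to have been no-ops, since the continue statements only skipped iterations of the inner loops, which the SQL clauses now implement properly. For a caller with filter.Types set and Limit 20, the executed query ends up shaped roughly like this (parameter numbering illustrative):

    // SELECT room_id, type FROM syncapi_account_data_type
    //  WHERE user_id = $1 AND id > $2 AND id <= $3
    //    AND type IN ($4)
    //  ORDER BY id ASC LIMIT 20

A side benefit is that the limit is now enforced in the database, so rows beyond it are never transferred to the application.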
@@ -47,15 +47,11 @@ const selectBackwardExtremitiesForRoomSQL = "" +
const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"

const deleteBackwardExtremitiesForRoomSQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"

type backwardExtremitiesStatements struct {
db *sql.DB
insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
}

func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {

@@ -75,9 +71,6 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
return nil, err
}
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
return nil, err
}
return s, nil
}

@@ -116,10 +109,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return err
}

func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
return err
}
@@ -90,7 +90,7 @@ const selectEventsWithEventIDsSQL = "" +

type currentRoomStateStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
deleteRoomStateForRoomStmt *sql.Stmt

@@ -100,7 +100,7 @@ type currentRoomStateStatements struct {
selectStateEventStmt *sql.Stmt
}

func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{
db: db,
streamIDStatements: streamID,
@@ -25,32 +25,48 @@ const (
// parts.
func prepareWithFilters(
db *sql.DB, txn *sql.Tx, query string, params []interface{},
senders, notsenders, types, nottypes []string, excludeEventIDs []string,
senders, notsenders, types, nottypes *[]string, excludeEventIDs []string,
limit int, order FilterOrder,
) (*sql.Stmt, []interface{}, error) {
offset := len(params)
if count := len(senders); count > 0 {
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range senders {
params, offset = append(params, v), offset+1
if senders != nil {
if count := len(*senders); count > 0 {
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range *senders {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND sender = ""`
}
}
if count := len(notsenders); count > 0 {
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range notsenders {
params, offset = append(params, v), offset+1
if notsenders != nil {
if count := len(*notsenders); count > 0 {
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range *notsenders {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND sender NOT = ""`
}
}
if count := len(types); count > 0 {
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range types {
params, offset = append(params, v), offset+1
if types != nil {
if count := len(*types); count > 0 {
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range *types {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND type = ""`
}
}
if count := len(nottypes); count > 0 {
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range nottypes {
params, offset = append(params, v), offset+1
if nottypes != nil {
if count := len(*nottypes); count > 0 {
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range *nottypes {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND type NOT = ""`
}
}
if count := len(excludeEventIDs); count > 0 {
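prepareWithFilters now distinguishes three cases per filter field: a nil pointer adds no clause (the field was absent), a non-empty list adds an IN/NOT IN clause with renumbered placeholders, and a pointer to an empty list adds a clause intended to match nothing, reflecting the Matrix filter semantics where an empty list excludes everything. Summarised (illustrative):

    // senders == nil               -> no sender clause at all
    // *senders == []string{"@a:x"} -> AND sender IN ($n)
    // *senders == []string{}       -> AND sender = ""   (matches no event)

One caveat for the empty-list arms of the negative filters: `sender NOT = ""` and `type NOT = ""` are not valid SQL, and `!=` (or `<>`) is presumably what was intended, so those branches look like they would error if ever hit.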
syncapi/storage/sqlite3/ignores_table.go (new file, 87 lines)

@@ -0,0 +1,87 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sqlite3

import (
"context"
"database/sql"
"encoding/json"

"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)

const ignoresSchema = `
-- Stores data about ignores
CREATE TABLE IF NOT EXISTS syncapi_ignores (
-- The user ID whose ignore list this belongs to.
user_id TEXT NOT NULL,
ignores_json TEXT NOT NULL,
PRIMARY KEY(user_id)
);
`

const selectIgnoresSQL = "" +
"SELECT ignores_json FROM syncapi_ignores WHERE user_id = $1"

const upsertIgnoresSQL = "" +
"INSERT INTO syncapi_ignores (user_id, ignores_json) VALUES ($1, $2)" +
" ON CONFLICT DO UPDATE set ignores_json = $2"

type ignoresStatements struct {
selectIgnoresStmt *sql.Stmt
upsertIgnoresStmt *sql.Stmt
}

func NewSqliteIgnoresTable(db *sql.DB) (tables.Ignores, error) {
_, err := db.Exec(ignoresSchema)
if err != nil {
return nil, err
}
s := &ignoresStatements{}
if s.selectIgnoresStmt, err = db.Prepare(selectIgnoresSQL); err != nil {
return nil, err
}
if s.upsertIgnoresStmt, err = db.Prepare(upsertIgnoresSQL); err != nil {
return nil, err
}
return s, nil
}

func (s *ignoresStatements) SelectIgnores(
ctx context.Context, userID string,
) (*types.IgnoredUsers, error) {
var ignoresData []byte
err := s.selectIgnoresStmt.QueryRowContext(ctx, userID).Scan(&ignoresData)
if err != nil {
return nil, err
}
var ignores types.IgnoredUsers
if err = json.Unmarshal(ignoresData, &ignores); err != nil {
return nil, err
}
return &ignores, nil
}

func (s *ignoresStatements) UpsertIgnores(
ctx context.Context, userID string, ignores *types.IgnoredUsers,
) error {
ignoresJSON, err := json.Marshal(ignores)
if err != nil {
return err
}
_, err = s.upsertIgnoresStmt.ExecContext(ctx, userID, ignoresJSON)
return err
}
@@ -59,14 +59,14 @@ const selectMaxInviteIDSQL = "" +

type inviteEventsStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt
deleteInviteEventStmt *sql.Stmt
selectMaxInviteIDStmt *sql.Stmt
}

func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) {
s := &inviteEventsStatements{
db: db,
streamIDStatements: streamID,
@@ -18,7 +18,6 @@ import (
"context"
"database/sql"
"fmt"
"strings"

"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"

@@ -57,12 +56,6 @@ const upsertMembershipSQL = "" +
" ON CONFLICT (room_id, user_id, membership)" +
" DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"

const selectMembershipSQL = "" +
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
" WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" +
" ORDER BY stream_pos DESC" +
" LIMIT 1"

const selectMembershipCountSQL = "" +
"SELECT COUNT(*) FROM (" +
" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +

@@ -111,22 +104,6 @@ func (s *membershipsStatements) UpsertMembership(
return err
}

func (s *membershipsStatements) SelectMembership(
ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string,
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
params := []interface{}{roomID, userID}
for _, membership := range memberships {
params = append(params, membership)
}
orig := strings.Replace(selectMembershipSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
stmt, err := s.db.Prepare(orig)
if err != nil {
return "", 0, 0, err
}
err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos)
return
}

func (s *membershipsStatements) SelectMembershipCount(
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
) (count int, err error) {
@@ -58,7 +58,7 @@ const insertEventSQL = "" +
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"

const selectEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)"

const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +

@@ -111,9 +111,8 @@ const selectContextAfterEventSQL = "" +

type outputRoomEventsStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt

@@ -122,7 +121,7 @@ type outputRoomEventsStatements struct {
selectContextAfterEventStmt *sql.Stmt
}

func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) {
s := &outputRoomEventsStatements{
db: db,
streamIDStatements: streamID,

@@ -133,7 +132,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
}
return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL},
{&s.selectEventsStmt, selectEventsSQL},
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
{&s.updateEventJSONStmt, updateEventJSONSQL},
{&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL},

@@ -421,21 +419,43 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
// selectEvents returns the events for the given event IDs. If an event is
// missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) SelectEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string,
ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool,
) ([]types.StreamEvent, error) {
var returnEvents []types.StreamEvent
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
for _, eventID := range eventIDs {
rows, err := stmt.QueryContext(ctx, eventID)
if err != nil {
return nil, err
}
if streamEvents, err := rowsToStreamEvents(rows); err == nil {
returnEvents = append(returnEvents, streamEvents...)
}
internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
iEventIDs := make([]interface{}, len(eventIDs))
for i := range eventIDs {
iEventIDs[i] = eventIDs[i]
}
return returnEvents, nil
selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
var rows *sql.Rows
var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, selectSQL, iEventIDs...)
} else {
rows, err = s.db.QueryContext(ctx, selectSQL, iEventIDs...)
}
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
streamEvents, err := rowsToStreamEvents(rows)
if err != nil {
return nil, err
}
if preserveOrder {
var returnEvents []types.StreamEvent
eventMap := make(map[string]types.StreamEvent)
for _, ev := range streamEvents {
eventMap[ev.EventID()] = ev
}
for _, eventID := range eventIDs {
ev, ok := eventMap[eventID]
if ok {
returnEvents = append(returnEvents, ev)
}
}
return returnEvents, nil
}
return streamEvents, nil
}

func (s *outputRoomEventsStatements) DeleteEventsForRoom(
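The SQLite SelectEvents rewrite replaces one query per event ID with a single IN (...) query, trading N round trips for one. The general technique: expand the single placeholder into n placeholders with sqlutil.QueryVariadic (its output shape is inferred from the usage above), then pass the IDs as a []interface{}:

    // Sketch of the variadic IN-clause pattern used above.
    args := make([]interface{}, len(eventIDs))
    for i, id := range eventIDs {
        args[i] = id
    }
    query := strings.Replace(baseSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
    rows, err := db.QueryContext(ctx, query, args...)

Because the query text depends on len(eventIDs), it cannot be prepared once up front, which is why selectEventsStmt is dropped from the StatementList and the SQL is built per call.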
@@ -78,7 +78,6 @@ type outputRoomEventsTopologyStatements struct {
selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt
deleteTopologyForRoomStmt *sql.Stmt
selectStreamToTopologicalPositionAscStmt *sql.Stmt
selectStreamToTopologicalPositionDescStmt *sql.Stmt
}

@@ -191,10 +190,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return
}

func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
return err
}
@@ -66,7 +66,7 @@ const selectMaxPeekIDSQL = "" +

type peekStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
insertPeekStmt *sql.Stmt
deletePeekStmt *sql.Stmt
deletePeeksStmt *sql.Stmt

@@ -75,7 +75,7 @@ type peekStatements struct {
selectMaxPeekIDStmt *sql.Stmt
}

func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) {
func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) {
_, err := db.Exec(peeksSchema)
if err != nil {
return nil, err
@@ -75,7 +75,7 @@ const selectPresenceAfter = "" +

type presenceStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
upsertPresenceStmt *sql.Stmt
upsertPresenceFromSyncStmt *sql.Stmt
selectPresenceForUsersStmt *sql.Stmt

@@ -83,7 +83,7 @@ type presenceStatements struct {
selectPresenceAfterStmt *sql.Stmt
}

func NewSqlitePresenceTable(db *sql.DB, streamID *streamIDStatements) (*presenceStatements, error) {
func NewSqlitePresenceTable(db *sql.DB, streamID *StreamIDStatements) (*presenceStatements, error) {
_, err := db.Exec(presenceSchema)
if err != nil {
return nil, err
@@ -59,13 +59,13 @@ const selectMaxReceiptIDSQL = "" +

type receiptStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
streamIDStatements *StreamIDStatements
upsertReceipt *sql.Stmt
selectRoomReceipts *sql.Stmt
selectMaxReceiptID *sql.Stmt
}

func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) {
func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) {
_, err := db.Exec(receiptsSchema)
if err != nil {
return nil, err
@@ -32,12 +32,12 @@ const increaseStreamIDStmt = "" +
"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" +
" RETURNING stream_id"

type streamIDStatements struct {
type StreamIDStatements struct {
db *sql.DB
increaseStreamIDStmt *sql.Stmt
}

func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) {
s.db = db
_, err = db.Exec(streamIDTableSchema)
if err != nil {

@@ -49,31 +49,31 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
return
}

func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos)
return
}

func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *StreamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
return
}

func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *StreamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos)
return
}

func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *StreamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
return
}

func (s *streamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos)
return
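Every per-stream ID allocator above shares the same UPDATE ... RETURNING statement, parameterised only by the stream name, so each named counter ("global", "receipt", "invite", "accountdata", "presence") is one row in syncapi_stream_id and the increment-and-read is a single atomic statement. An illustrative call, matching the pattern of the methods above:

    // Allocate the next receipt position inside an existing transaction.
    var pos types.StreamPosition
    stmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
    if err := stmt.QueryRowContext(ctx, "receipt").Scan(&pos); err != nil {
        return 0, err
    }

The rename from streamIDStatements/prepare to StreamIDStatements/Prepare exports the type, presumably so that code outside the sqlite3 package (such as tests) can construct and prepare it.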
@@ -30,7 +30,7 @@ type SyncServerDatasource struct {
shared.Database
db *sql.DB
writer sqlutil.Writer
streamID streamIDStatements
streamID StreamIDStatements
}

// NewDatabase creates a new sync server database

@@ -49,7 +49,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
}

func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
if err = d.streamID.prepare(d.db); err != nil {
if err = d.streamID.Prepare(d.db); err != nil {
return err
}
accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)

@@ -100,6 +100,10 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
if err != nil {
return err
}
ignores, err := NewSqliteIgnoresTable(d.db)
if err != nil {
return err
}
presence, err := NewSqlitePresenceTable(d.db, &d.streamID)
if err != nil {
return err

@@ -125,6 +129,7 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
Receipts: receipts,
Memberships: memberships,
NotificationData: notificationData,
Ignores: ignores,
Presence: presence,
}
return nil
@@ -1,121 +1,29 @@
package storage_test

// TODO: Fix these tests
/*
import (
"context"
"crypto/ed25519"
"encoding/json"
"fmt"
"os"
"reflect"
"testing"
"time"

"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
)

var (
ctx = context.Background()
emptyStateKey = ""
testOrigin = gomatrixserverlib.ServerName("hollow.knight")
testRoomID = fmt.Sprintf("!hallownest:%s", testOrigin)
testUserIDA = fmt.Sprintf("@hornet:%s", testOrigin)
testUserIDB = fmt.Sprintf("@paleking:%s", testOrigin)
testUserDeviceA = userapi.Device{
UserID: testUserIDA,
ID: "device_id_A",
DisplayName: "Device A",
}
testRoomVersion = gomatrixserverlib.RoomVersionV4
testKeyID = gomatrixserverlib.KeyID("ed25519:storage_test")
testPrivateKey = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
})
)
var ctx = context.Background()

func MustCreateEvent(t *testing.T, roomID string, prevs []*gomatrixserverlib.HeaderedEvent, b *gomatrixserverlib.EventBuilder) *gomatrixserverlib.HeaderedEvent {
b.RoomID = roomID
if prevs != nil {
prevIDs := make([]string, len(prevs))
for i := range prevs {
prevIDs[i] = prevs[i].EventID()
}
b.PrevEvents = prevIDs
}
e, err := b.Build(time.Now(), testOrigin, testKeyID, testPrivateKey, testRoomVersion)
if err != nil {
t.Fatalf("failed to build event: %s", err)
}
return e.Headered(testRoomVersion)
}

func MustCreateDatabase(t *testing.T) storage.Database {
dbname := fmt.Sprintf("test_%s.db", t.Name())
if _, err := os.Stat(dbname); err == nil {
if err = os.Remove(dbname); err != nil {
t.Fatalf("tried to delete stale test database but failed: %s", err)
}
}
db, err := sqlite3.NewDatabase(&config.DatabaseOptions{
ConnectionString: config.DataSource(fmt.Sprintf("file:%s", dbname)),
func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewSyncServerDatasource(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("NewSyncServerDatasource returned %s", err)
}
return db
}

// Create a list of events which include a create event, join event and some messages.
func SimpleRoom(t *testing.T, roomID, userA, userB string) (msgs []*gomatrixserverlib.HeaderedEvent, state []*gomatrixserverlib.HeaderedEvent) {
var events []*gomatrixserverlib.HeaderedEvent
events = append(events, MustCreateEvent(t, roomID, nil, &gomatrixserverlib.EventBuilder{
Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, userA)),
Type: "m.room.create",
StateKey: &emptyStateKey,
Sender: userA,
Depth: int64(len(events) + 1),
}))
state = append(state, events[len(events)-1])
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
Content: []byte(`{"membership":"join"}`),
Type: "m.room.member",
StateKey: &userA,
Sender: userA,
Depth: int64(len(events) + 1),
}))
state = append(state, events[len(events)-1])
for i := 0; i < 10; i++ {
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
Content: []byte(fmt.Sprintf(`{"body":"Message A %d"}`, i+1)),
Type: "m.room.message",
Sender: userA,
Depth: int64(len(events) + 1),
}))
}
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
Content: []byte(`{"membership":"join"}`),
Type: "m.room.member",
StateKey: &userB,
Sender: userB,
Depth: int64(len(events) + 1),
}))
state = append(state, events[len(events)-1])
for i := 0; i < 10; i++ {
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
Content: []byte(fmt.Sprintf(`{"body":"Message B %d"}`, i+1)),
Type: "m.room.message",
Sender: userB,
Depth: int64(len(events) + 1),
}))
}

return events, state
return db, close
}

func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) {
|
||||
|
|
@ -131,206 +39,157 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
|
|||
if err != nil {
|
||||
t.Fatalf("WriteEvent failed: %s", err)
|
||||
}
|
||||
fmt.Println("Event ID", ev.EventID(), " spos=", pos, "depth=", ev.Depth())
|
||||
t.Logf("Event ID %s spos=%v depth=%v", ev.EventID(), pos, ev.Depth())
|
||||
positions = append(positions, pos)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestWriteEvents(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := MustCreateDatabase(t)
|
||||
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
||||
MustWriteEvents(t, db, events)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
alice := test.NewUser()
|
||||
r := test.NewRoom(t, alice)
|
||||
db, close := MustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
MustWriteEvents(t, db, r.Events())
|
||||
})
|
||||
}
|
||||
|
||||
// These tests assert basic functionality of the IncrementalSync and CompleteSync functions.
|
||||
func TestSyncResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := MustCreateDatabase(t)
|
||||
events, state := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
||||
positions := MustWriteEvents(t, db, events)
|
||||
latest, err := db.SyncPosition(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get SyncPosition: %s", err)
|
||||
}
|
||||
// These tests assert basic functionality of RecentEvents for PDUs
|
||||
func TestRecentEventsPDU(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := MustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser()
|
||||
// dummy room to make sure SQL queries are filtering on room ID
|
||||
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
|
||||
|
||||
testCases := []struct {
|
||||
Name string
|
||||
DoSync func() (*types.Response, error)
|
||||
WantTimeline []*gomatrixserverlib.HeaderedEvent
|
||||
WantState []*gomatrixserverlib.HeaderedEvent
|
||||
}{
|
||||
// The purpose of this test is to make sure that incremental syncs are including up to the latest events.
|
||||
// It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event.
|
||||
// It makes sure the response includes the final event.
|
||||
{
|
||||
Name: "IncrementalSync penultimate",
|
||||
DoSync: func() (*types.Response, error) {
|
||||
from := types.StreamingToken{ // pretend we are at the penultimate event
|
||||
PDUPosition: positions[len(positions)-2],
|
||||
// actual test room
|
||||
r := test.NewRoom(t, alice)
|
||||
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"})
|
||||
events := r.Events()
|
||||
positions := MustWriteEvents(t, db, events)
|
||||
|
||||
// dummy room to make sure SQL queries are filtering on room ID
|
||||
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
|
||||
|
||||
latest, err := db.MaxStreamPositionForPDUs(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
Name string
|
||||
From types.StreamPosition
|
||||
To types.StreamPosition
|
||||
Limit int
|
||||
ReverseOrder bool
|
||||
WantEvents []*gomatrixserverlib.HeaderedEvent
|
||||
WantLimited bool
|
||||
}{
|
||||
// The purpose of this test is to make sure that incremental syncs are including up to the latest events.
|
||||
// It's a basic sanity test that sync works. It creates a streaming position that is on the penultimate event.
|
||||
// It makes sure the response includes the final event.
|
||||
{
|
||||
Name: "penultimate",
|
||||
From: positions[len(positions)-2], // pretend we are at the penultimate event
|
||||
To: latest,
|
||||
Limit: 100,
|
||||
WantEvents: events[len(events)-1:],
|
||||
WantLimited: false,
|
||||
},
|
||||
// The purpose of this test is to check that limits can be applied and work.
|
||||
// This is critical for big rooms hence the test here.
|
||||
{
|
||||
Name: "limited",
|
||||
From: 0,
|
||||
To: latest,
|
||||
Limit: 1,
|
||||
WantEvents: events[len(events)-1:],
|
||||
WantLimited: true,
|
||||
},
|
||||
// The purpose of this test is to check that we can return every event with a high
|
||||
// enough limit
|
||||
{
|
||||
Name: "large limited",
|
||||
From: 0,
|
||||
To: latest,
|
||||
Limit: 100,
|
||||
WantEvents: events,
|
||||
WantLimited: false,
|
||||
},
|
||||
// The purpose of this test is to check that we can return events in reverse order
|
||||
{
|
||||
Name: "reverse",
|
||||
From: positions[len(positions)-3], // 2 events back
|
||||
To: latest,
|
||||
Limit: 100,
|
||||
ReverseOrder: true,
|
||||
WantEvents: test.Reversed(events[len(events)-2:]),
|
||||
WantLimited: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range testCases {
|
||||
tc := testCases[i]
|
||||
t.Run(tc.Name, func(st *testing.T) {
|
||||
var filter gomatrixserverlib.RoomEventFilter
|
||||
filter.Limit = tc.Limit
|
||||
gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
|
||||
From: tc.From,
|
||||
To: tc.To,
|
||||
}, &filter, !tc.ReverseOrder, true)
|
||||
if err != nil {
|
||||
st.Fatalf("failed to do sync: %s", err)
|
||||
}
|
||||
res := types.NewResponse()
|
||||
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
|
||||
},
|
||||
WantTimeline: events[len(events)-1:],
|
||||
},
|
||||
// The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the
|
||||
// number of returned events. This is critical for big rooms hence the test here.
|
||||
{
|
||||
Name: "IncrementalSync limited",
|
||||
DoSync: func() (*types.Response, error) {
|
||||
from := types.StreamingToken{ // pretend we are 10 events behind
|
||||
PDUPosition: positions[len(positions)-11],
|
||||
if limited != tc.WantLimited {
|
||||
st.Errorf("got limited=%v want %v", limited, tc.WantLimited)
|
||||
}
|
||||
res := types.NewResponse()
|
||||
// limit is set to 5
|
||||
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
|
||||
},
|
||||
// want the last 5 events, NOT the last 10.
|
||||
WantTimeline: events[len(events)-5:],
|
||||
},
|
||||
// The purpose of this test is to check that CompleteSync returns all the current state as well as
|
||||
// honouring the `numRecentEventsPerRoom` value
|
||||
{
|
||||
Name: "CompleteSync limited",
|
||||
DoSync: func() (*types.Response, error) {
|
||||
res := types.NewResponse()
|
||||
// limit set to 5
|
||||
return db.CompleteSync(ctx, res, testUserDeviceA, 5)
|
||||
},
|
||||
// want the last 5 events
|
||||
WantTimeline: events[len(events)-5:],
|
||||
// want all state for the room
|
||||
WantState: state,
|
||||
},
|
||||
// The purpose of this test is to check that CompleteSync can return everything with a high enough
|
||||
// `numRecentEventsPerRoom`.
|
||||
{
|
||||
Name: "CompleteSync",
|
||||
DoSync: func() (*types.Response, error) {
|
||||
res := types.NewResponse()
|
||||
return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
|
||||
},
|
||||
WantTimeline: events,
|
||||
// We want no state at all as that field in /sync is the delta between the token (beginning of time)
|
||||
// and the START of the timeline.
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Name, func(st *testing.T) {
|
||||
res, err := tc.DoSync()
|
||||
if err != nil {
|
||||
st.Fatalf("failed to do sync: %s", err)
|
||||
}
|
||||
next := types.StreamingToken{
|
||||
PDUPosition: latest.PDUPosition,
|
||||
TypingPosition: latest.TypingPosition,
|
||||
ReceiptPosition: latest.ReceiptPosition,
|
||||
SendToDevicePosition: latest.SendToDevicePosition,
|
||||
}
|
||||
if res.NextBatch.String() != next.String() {
|
||||
st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String())
|
||||
}
|
||||
roomRes, ok := res.Rooms.Join[testRoomID]
|
||||
if !ok {
|
||||
st.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res)
|
||||
}
|
||||
assertEventsEqual(st, "state for "+testRoomID, false, roomRes.State.Events, tc.WantState)
|
||||
assertEventsEqual(st, "timeline for "+testRoomID, false, roomRes.Timeline.Events, tc.WantTimeline)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := MustCreateDatabase(t)
|
||||
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
||||
positions := MustWriteEvents(t, db, events)
|
||||
latest, err := db.SyncPosition(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get SyncPosition: %s", err)
|
||||
}
|
||||
from := types.StreamingToken{
|
||||
PDUPosition: positions[len(positions)-2],
|
||||
}
|
||||
|
||||
res := types.NewResponse()
|
||||
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to IncrementalSync with latest token")
|
||||
}
|
||||
roomRes, ok := res.Rooms.Join[testRoomID]
|
||||
if !ok {
|
||||
t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res)
|
||||
}
|
||||
// returns the last event "Message 10"
|
||||
assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:]))
|
||||
|
||||
prev := roomRes.Timeline.PrevBatch.String()
|
||||
if prev == "" {
|
||||
t.Fatalf("IncrementalSync expected prev_batch token")
|
||||
}
|
||||
prevBatchToken, err := types.NewTopologyTokenFromString(prev)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to NewTopologyTokenFromString : %s", err)
|
||||
}
|
||||
// backpaginate 5 messages starting at the latest position.
|
||||
// head towards the beginning of time
|
||||
to := types.TopologyToken{}
|
||||
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true)
|
||||
if err != nil {
|
||||
t.Fatalf("GetEventsInRange returned an error: %s", err)
|
||||
}
|
||||
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
|
||||
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1]))
|
||||
}
|
||||
|
||||
// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token.
|
||||
func TestGetEventsInRangeWithStreamToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := MustCreateDatabase(t)
|
||||
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
||||
MustWriteEvents(t, db, events)
|
||||
latest, err := db.SyncPosition(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get SyncPosition: %s", err)
|
||||
}
|
||||
// head towards the beginning of time
|
||||
to := types.StreamingToken{}
|
||||
|
||||
// backpaginate 5 messages starting at the latest position.
|
||||
paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
|
||||
if err != nil {
|
||||
t.Fatalf("GetEventsInRange returned an error: %s", err)
|
||||
}
|
||||
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
|
||||
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
|
||||
if len(gotEvents) != len(tc.WantEvents) {
|
||||
st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents))
|
||||
}
|
||||
for j := range gotEvents {
|
||||
if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) {
|
||||
st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON()))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token
|
||||
func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := MustCreateDatabase(t)
|
||||
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
||||
MustWriteEvents(t, db, events)
|
||||
from, err := db.MaxTopologicalPosition(ctx, testRoomID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
|
||||
}
|
||||
// head towards the beginning of time
|
||||
to := types.TopologyToken{}
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := MustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser()
|
||||
r := test.NewRoom(t, alice)
|
||||
for i := 0; i < 10; i++ {
|
||||
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})
|
||||
}
|
||||
events := r.Events()
|
||||
_ = MustWriteEvents(t, db, events)
|
||||
|
||||
// backpaginate 5 messages starting at the latest position.
|
||||
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true)
|
||||
if err != nil {
|
||||
t.Fatalf("GetEventsInRange returned an error: %s", err)
|
||||
}
|
||||
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
|
||||
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
|
||||
from, err := db.MaxTopologicalPosition(ctx, r.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
|
||||
}
|
||||
t.Logf("max topo pos = %+v", from)
|
||||
// head towards the beginning of time
|
||||
to := types.TopologyToken{}
|
||||
|
||||
// backpaginate 5 messages starting at the latest position.
|
||||
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, 5, true)
|
||||
if err != nil {
|
||||
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
|
||||
}
|
||||
gots := db.StreamEventsToEvents(nil, paginatedEvents)
|
||||
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth.
|
||||
// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent
|
||||
// will appear FIRST when going backwards. This test creates a DAG like:
|
||||
|
|
@ -740,12 +599,4 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ
|
|||
tok.Decrement()
|
||||
return &tok
|
||||
}
|
||||
|
||||
func reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
|
||||
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
|
||||
for i := 0; i < len(in); i++ {
|
||||
out[i] = in[len(in)-i-1]
|
||||
}
|
||||
return out
|
||||
}
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ type Events interface {
|
|||
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
|
||||
// SelectEarlyEvents returns the earliest events in the given room.
|
||||
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
|
||||
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
|
||||
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool) ([]types.StreamEvent, error)
|
||||
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
|
||||
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
|
||||
DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
|
||||
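The added `preserveOrder` flag asks `SelectEvents` to return results in the same order as the requested `eventIDs` rather than in database order; the new `output_room_events_test.go` below exercises this with a shuffled ID list. Internally that only needs a reorder step over the fetched rows, sketched here under the assumption that `fetched` holds the raw query results (illustrative, not the actual implementation):

	// Sketch: reorder fetched events to match the requested ID order.
	byID := make(map[string]types.StreamEvent, len(fetched))
	for _, ev := range fetched {
		byID[ev.EventID()] = ev
	}
	ordered := make([]types.StreamEvent, 0, len(eventIDs))
	for _, id := range eventIDs {
		if ev, ok := byID[id]; ok {
			ordered = append(ordered, ev)
		}
	}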
|
|
@ -84,8 +84,6 @@ type Topology interface {
|
|||
SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error)
|
||||
// SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position.
|
||||
SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error)
|
||||
// DeleteTopologyForRoom removes all topological information for a room. This should only be done when removing the room entirely.
|
||||
DeleteTopologyForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
|
||||
// SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room.
|
||||
SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error)
|
||||
}
|
||||
|
|
@ -132,8 +130,6 @@ type BackwardsExtremities interface {
|
|||
SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error)
|
||||
// DeleteBackwardExtremity removes a backwards extremity for a room, if one existed.
|
||||
DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error)
|
||||
// DeleteBackwardExtremitiesForRoom removes all backward extremities for a room. This should only be done when removing the room entirely.
|
||||
DeleteBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
|
||||
}
|
||||
|
||||
// SendToDevice tracks send-to-device messages which are sent to individual
|
||||
|
|
@ -173,7 +169,6 @@ type Receipts interface {
|
|||
|
||||
type Memberships interface {
|
||||
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
|
||||
SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error)
|
||||
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
|
||||
}
|
||||
|
||||
|
|
@ -183,6 +178,11 @@ type NotificationData interface {
|
|||
SelectMaxID(ctx context.Context) (int64, error)
|
||||
}
|
||||
|
||||
type Ignores interface {
|
||||
SelectIgnores(ctx context.Context, userID string) (*types.IgnoredUsers, error)
|
||||
UpsertIgnores(ctx context.Context, userID string, ignores *types.IgnoredUsers) error
|
||||
}
|
||||
|
||||
type Presence interface {
|
||||
UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error)
|
||||
GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error)
|
||||
|
|
|
|||
82
syncapi/storage/tables/output_room_events_test.go
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
)
|
||||
|
||||
func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open db: %s", err)
|
||||
}
|
||||
|
||||
var tab tables.Events
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
tab, err = postgres.NewPostgresEventsTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
var stream sqlite3.StreamIDStatements
|
||||
if err = stream.Prepare(db); err != nil {
|
||||
t.Fatalf("failed to prepare stream stmts: %s", err)
|
||||
}
|
||||
tab, err = sqlite3.NewSqliteEventsTable(db, &stream)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make new table: %s", err)
|
||||
}
|
||||
return tab, db, close
|
||||
}
|
||||
|
||||
func TestOutputRoomEventsTable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
alice := test.NewUser()
|
||||
room := test.NewRoom(t, alice)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, db, close := newOutputRoomEventsTable(t, dbType)
|
||||
defer close()
|
||||
events := room.Events()
|
||||
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
|
||||
for _, ev := range events {
|
||||
_, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to InsertEvent: %s", err)
|
||||
}
|
||||
}
|
||||
// order = 2,0,3,1
|
||||
wantEventIDs := []string{
|
||||
events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(),
|
||||
}
|
||||
gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEvents: %s", err)
|
||||
}
|
||||
gotEventIDs := make([]string, len(gotEvents))
|
||||
for i := range gotEvents {
|
||||
gotEventIDs[i] = gotEvents[i].EventID()
|
||||
}
|
||||
if !reflect.DeepEqual(gotEventIDs, wantEventIDs) {
|
||||
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
91
syncapi/storage/tables/topology_test.go
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
)
|
||||
|
||||
func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open db: %s", err)
|
||||
}
|
||||
|
||||
var tab tables.Topology
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
tab, err = postgres.NewPostgresTopologyTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
tab, err = sqlite3.NewSqliteTopologyTable(db)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make new table: %s", err)
|
||||
}
|
||||
return tab, db, close
|
||||
}
|
||||
|
||||
func TestTopologyTable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
alice := test.NewUser()
|
||||
room := test.NewRoom(t, alice)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, db, close := newTopologyTable(t, dbType)
|
||||
defer close()
|
||||
events := room.Events()
|
||||
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
|
||||
var highestPos types.StreamPosition
|
||||
for i, ev := range events {
|
||||
topoPos, err := tab.InsertEventInTopology(ctx, txn, ev, types.StreamPosition(i))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to InsertEventInTopology: %s", err)
|
||||
}
|
||||
// topo pos = depth, depth starts at 1, hence 1+i
|
||||
if topoPos != types.StreamPosition(1+i) {
|
||||
return fmt.Errorf("got topo pos %d want %d", topoPos, 1+i)
|
||||
}
|
||||
highestPos = topoPos + 1
|
||||
}
|
||||
// check ordering works without limit
|
||||
eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||
}
|
||||
test.AssertEventIDsEqual(t, eventIDs, events[:])
|
||||
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||
}
|
||||
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:]))
|
||||
// check ordering works with limit
|
||||
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||
}
|
||||
test.AssertEventIDsEqual(t, eventIDs, events[:3])
|
||||
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||
}
|
||||
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:]))
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -54,6 +54,10 @@ func (p *InviteStreamProvider) IncrementalSync(
|
|||
}
|
||||
|
||||
for roomID, inviteEvent := range invites {
|
||||
// skip ignored user events
|
||||
if _, ok := req.IgnoredUsers.List[inviteEvent.Sender()]; ok {
|
||||
continue
|
||||
}
|
||||
ir := types.NewInviteResponse(inviteEvent)
|
||||
req.Response.Rooms.Invite[roomID] = *ir
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package streams
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
|
@ -25,6 +26,7 @@ type PDUStreamProvider struct {
|
|||
|
||||
tasks chan func()
|
||||
workers atomic.Int32
|
||||
userAPI userapi.UserInternalAPI
|
||||
}
|
||||
|
||||
func (p *PDUStreamProvider) worker() {
|
||||
|
|
@ -87,6 +89,10 @@ func (p *PDUStreamProvider) CompleteSync(
|
|||
stateFilter := req.Filter.Room.State
|
||||
eventFilter := req.Filter.Room.Timeline
|
||||
|
||||
if err = p.addIgnoredUsersToFilter(ctx, req, &eventFilter); err != nil {
|
||||
req.Log.WithError(err).Error("unable to update event filter with ignored users")
|
||||
}
|
||||
|
||||
// Build up a /sync response. Add joined rooms.
|
||||
var reqMutex sync.Mutex
|
||||
var reqWaitGroup sync.WaitGroup
|
||||
|
|
@ -175,6 +181,10 @@ func (p *PDUStreamProvider) IncrementalSync(
|
|||
return to
|
||||
}
|
||||
|
||||
if err = p.addIgnoredUsersToFilter(ctx, req, &eventFilter); err != nil {
|
||||
req.Log.WithError(err).Error("unable to update event filter with ignored users")
|
||||
}
|
||||
|
||||
newPos = from
|
||||
for _, delta := range stateDeltas {
|
||||
var pos types.StreamPosition
|
||||
|
|
@ -402,6 +412,27 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
|||
return jr, nil
|
||||
}
|
||||
|
||||
// addIgnoredUsersToFilter adds ignored users to the event filter and
|
||||
// to the sync request itself for further use in streams.
|
||||
func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error {
|
||||
ignores, err := p.DB.IgnoresForUser(ctx, req.Device.UserID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
req.IgnoredUsers = *ignores
|
||||
userList := make([]string, 0, len(ignores.List))
|
||||
for userID := range ignores.List {
|
||||
userList = append(userList, userID)
|
||||
}
|
||||
if len(userList) > 0 {
|
||||
eventFilter.NotSenders = &userList
|
||||
}
|
||||
return nil
|
||||
}
|
||||
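In effect the helper just translates the stored ignore list into the filter's `not_senders` field, which the event tables then apply when selecting recent events. Roughly equivalent, with a made-up user ID for illustration:

	// Sketch: what the filter looks like for a user who ignores @spammer.
	var filter gomatrixserverlib.RoomEventFilter
	ignored := []string{"@spammer:example.org"}
	filter.NotSenders = &ignored // recent-event queries now exclude these senders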
|
||||
func removeDuplicates(stateEvents, recentEvents []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
|
||||
for _, recentEv := range recentEvents {
|
||||
if recentEv.StateKey() == nil {
|
||||
|
|
|
|||
|
|
@ -111,10 +111,15 @@ func (p *PresenceStreamProvider) IncrementalSync(
|
|||
continue
|
||||
}
|
||||
}
|
||||
presence.ClientFields.LastActiveAgo = presence.LastActiveAgo()
|
||||
if presence.ClientFields.Presence == "online" {
|
||||
currentlyActive := presence.CurrentlyActive()
|
||||
presence.ClientFields.CurrentlyActive = &currentlyActive
|
||||
|
||||
if _, known := types.PresenceFromString(presence.ClientFields.Presence); known {
|
||||
presence.ClientFields.LastActiveAgo = presence.LastActiveAgo()
|
||||
if presence.ClientFields.Presence == "online" {
|
||||
currentlyActive := presence.CurrentlyActive()
|
||||
presence.ClientFields.CurrentlyActive = &currentlyActive
|
||||
}
|
||||
} else {
|
||||
presence.ClientFields.Presence = "offline"
|
||||
}
|
||||
|
||||
content, err := json.Marshal(presence.ClientFields)
|
||||
|
|
|
|||
|
|
@ -54,6 +54,10 @@ func (p *ReceiptStreamProvider) IncrementalSync(
|
|||
// Group receipts by room, so we can create one ClientEvent for every room
|
||||
receiptsByRoom := make(map[string][]types.OutputReceiptEvent)
|
||||
for _, receipt := range receipts {
|
||||
// skip ignored user events
|
||||
if _, ok := req.IgnoredUsers.List[receipt.UserID]; ok {
|
||||
continue
|
||||
}
|
||||
receiptsByRoom[receipt.RoomID] = append(receiptsByRoom[receipt.RoomID], receipt)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -48,6 +48,10 @@ func (p *SendToDeviceStreamProvider) IncrementalSync(
|
|||
|
||||
// Add the updates into the sync response.
|
||||
for _, event := range events {
|
||||
// skip ignored user events
|
||||
if _, ok := req.IgnoredUsers.List[event.Sender]; ok {
|
||||
continue
|
||||
}
|
||||
req.Response.ToDevice.Events = append(req.Response.ToDevice.Events, event.SendToDeviceEvent)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,11 +40,18 @@ func (p *TypingStreamProvider) IncrementalSync(
|
|||
if users, updated := p.EDUCache.GetTypingUsersIfUpdatedAfter(
|
||||
roomID, int64(from),
|
||||
); updated {
|
||||
typingUsers := make([]string, 0, len(users))
|
||||
for i := range users {
|
||||
// skip ignored user events
|
||||
if _, ok := req.IgnoredUsers.List[users[i]]; !ok {
|
||||
typingUsers = append(typingUsers, users[i])
|
||||
}
|
||||
}
|
||||
ev := gomatrixserverlib.ClientEvent{
|
||||
Type: gomatrixserverlib.MTyping,
|
||||
}
|
||||
ev.Content, err = json.Marshal(map[string]interface{}{
|
||||
"user_ids": users,
|
||||
"user_ids": typingUsers,
|
||||
})
|
||||
if err != nil {
|
||||
req.Log.WithError(err).Error("json.Marshal failed")
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ func NewSyncStreamProviders(
|
|||
streams := &Streams{
|
||||
PDUStreamProvider: &PDUStreamProvider{
|
||||
StreamProvider: StreamProvider{DB: d},
|
||||
userAPI: userAPI,
|
||||
},
|
||||
TypingStreamProvider: &TypingStreamProvider{
|
||||
StreamProvider: StreamProvider{DB: d},
|
||||
|
|
|
|||
|
|
@ -67,6 +67,9 @@ func NewRequestPool(
|
|||
streams *streams.Streams, notifier *notifier.Notifier,
|
||||
producer PresencePublisher,
|
||||
) *RequestPool {
|
||||
prometheus.MustRegister(
|
||||
activeSyncRequests, waitingSyncRequests,
|
||||
)
|
||||
rp := &RequestPool{
|
||||
db: db,
|
||||
cfg: cfg,
|
||||
|
|
@ -122,10 +125,8 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user
|
|||
|
||||
presenceID, ok := types.PresenceFromString(presence)
|
||||
if !ok { // this should almost never happen
|
||||
logrus.Errorf("unknown presence '%s'", presence)
|
||||
return
|
||||
}
|
||||
|
||||
newPresence := types.PresenceInternal{
|
||||
ClientFields: types.PresenceClientResponse{
|
||||
Presence: presenceID.String(),
|
||||
|
|
@ -186,12 +187,6 @@ func (rp *RequestPool) updateLastSeen(req *http.Request, device *userapi.Device)
|
|||
rp.lastseen.Store(device.UserID+device.ID, time.Now())
|
||||
}
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(
|
||||
activeSyncRequests, waitingSyncRequests,
|
||||
)
|
||||
}
|
||||
|
||||
var activeSyncRequests = prometheus.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: "dendrite",
|
||||
|
|
|
|||
|
|
@ -21,25 +21,41 @@ import (
|
|||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
//go:generate stringer -type=Presence -linecomment
|
||||
type Presence uint8
|
||||
|
||||
const (
|
||||
PresenceUnavailable Presence = iota + 1 // unavailable
|
||||
PresenceOnline // online
|
||||
PresenceOffline // offline
|
||||
PresenceUnknown Presence = iota
|
||||
PresenceUnavailable // unavailable
|
||||
PresenceOnline // online
|
||||
PresenceOffline // offline
|
||||
)
|
||||
|
||||
func (p Presence) String() string {
|
||||
switch p {
|
||||
case PresenceUnavailable:
|
||||
return "unavailable"
|
||||
case PresenceOnline:
|
||||
return "online"
|
||||
case PresenceOffline:
|
||||
return "offline"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// PresenceFromString returns the integer representation of the given input presence.
|
||||
// Returns ok=false if the input is not a valid presence value.
|
||||
func PresenceFromString(input string) (p Presence, ok bool) {
|
||||
for i := 0; i < len(_Presence_index)-1; i++ {
|
||||
l, r := _Presence_index[i], _Presence_index[i+1]
|
||||
if strings.EqualFold(input, _Presence_name[l:r]) {
|
||||
return Presence(i + 1), true
|
||||
}
|
||||
func PresenceFromString(input string) (Presence, bool) {
|
||||
switch strings.ToLower(input) {
|
||||
case "unavailable":
|
||||
return PresenceUnavailable, true
|
||||
case "online":
|
||||
return PresenceOnline, true
|
||||
case "offline":
|
||||
return PresenceOffline, true
|
||||
default:
|
||||
return PresenceUnknown, false
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
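The hand-written replacement keeps the case-insensitive matching that the generated stringer code provided, and unknown inputs now map to the explicit `PresenceUnknown` zero value instead of relying on index arithmetic. For example (expected results taken from the test removed below):

	p, ok := types.PresenceFromString("OnLINE") // p == types.PresenceOnline, ok == true
	p, ok = types.PresenceFromString("unknown") // p == types.PresenceUnknown, ok == false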
|
||||
type PresenceInternal struct {
|
||||
|
|
|
|||
|
|
@ -1,26 +0,0 @@
|
|||
// Code generated by "stringer -type=Presence -linecomment"; DO NOT EDIT.
|
||||
|
||||
package types
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[PresenceUnavailable-1]
|
||||
_ = x[PresenceOnline-2]
|
||||
_ = x[PresenceOffline-3]
|
||||
}
|
||||
|
||||
const _Presence_name = "unavailableonlineoffline"
|
||||
|
||||
var _Presence_index = [...]uint8{0, 11, 17, 24}
|
||||
|
||||
func (i Presence) String() string {
|
||||
i -= 1
|
||||
if i >= Presence(len(_Presence_index)-1) {
|
||||
return "Presence(" + strconv.FormatInt(int64(i+1), 10) + ")"
|
||||
}
|
||||
return _Presence_name[_Presence_index[i]:_Presence_index[i+1]]
|
||||
}
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
package types
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestPresenceFromString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantStatus Presence
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "presence unavailable",
|
||||
input: "unavailable",
|
||||
wantStatus: PresenceUnavailable,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "presence online",
|
||||
input: "OnLINE",
|
||||
wantStatus: PresenceOnline,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "unknown presence",
|
||||
input: "unknown",
|
||||
wantStatus: 0,
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := PresenceFromString(tt.input)
|
||||
if got != tt.wantStatus {
|
||||
t.Errorf("PresenceFromString() got = %v, want %v", got, tt.wantStatus)
|
||||
}
|
||||
if got1 != tt.wantOk {
|
||||
t.Errorf("PresenceFromString() got1 = %v, want %v", got1, tt.wantOk)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -21,6 +21,8 @@ type SyncRequest struct {
|
|||
|
||||
// Updated by the PDU stream.
|
||||
Rooms map[string]string
|
||||
// Updated by the PDU stream.
|
||||
IgnoredUsers IgnoredUsers
|
||||
}
|
||||
|
||||
type StreamProvider interface {
|
||||
|
|
|
|||
|
|
@ -518,3 +518,7 @@ type OutputSendToDeviceEvent struct {
|
|||
DeviceID string `json:"device_id"`
|
||||
gomatrixserverlib.SendToDeviceEvent
|
||||
}
|
||||
|
||||
type IgnoredUsers struct {
|
||||
List map[string]interface{} `json:"ignored_users"`
|
||||
}
|
||||
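`IgnoredUsers` models the content of the `m.ignored_user_list` account data event, in which the ignored user IDs are the keys of the `ignored_users` object. A small decoding sketch, with a made-up user ID:

	raw := []byte(`{"ignored_users":{"@spammer:example.org":{}}}`)
	var ignored types.IgnoredUsers
	if err := json.Unmarshal(raw, &ignored); err != nil {
		// handle malformed account data
	}
	_, isIgnored := ignored.List["@spammer:example.org"] // isIgnored == true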
|
|
|
|||
|
|
@ -2,10 +2,6 @@
|
|||
|
||||
Latest account data appears in v2 /sync
|
||||
|
||||
# Blacklisted because we don't support ignores yet
|
||||
|
||||
Ignore invite in incremental sync
|
||||
|
||||
# Relies on a rejected PL event which will never be accepted into the DAG
|
||||
|
||||
# Caused by <https://github.com/matrix-org/sytest/pull/911>
|
||||
|
|
|
|||
|
|
@ -284,8 +284,6 @@ local user can join room with version 4
|
|||
remote user can join room with version 3
|
||||
remote user can join room with version 4
|
||||
Remote user can backfill in a room with version 4
|
||||
# We don't support ignores yet, so ignore this for now - ha ha.
|
||||
# Ignore invite in incremental sync
|
||||
Outbound federation can send invites via v2 API
|
||||
User can invite local user to room with version 3
|
||||
User can invite local user to room with version 4
|
||||
|
|
@ -695,4 +693,9 @@ Presence changes to UNAVAILABLE are reported to remote room members
|
|||
New federated private chats get full presence information (SYN-115)
|
||||
/upgrade copies >100 power levels to the new room
|
||||
Room state after a rejected message event is the same as before
|
||||
Room state after a rejected state event is the same as before
|
||||
Ignore user in existing room
|
||||
Ignore invite in full sync
|
||||
Ignore invite in incremental sync
|
||||
A filtered timeline reaches its limit
|
||||
A change to displayname should not result in a full state sync
|
||||
128
test/db.go
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type DBType int
|
||||
|
||||
var DBTypeSQLite DBType = 1
|
||||
var DBTypePostgres DBType = 2
|
||||
|
||||
var Quiet = false
|
||||
|
||||
func createLocalDB(dbName string) string {
|
||||
if !Quiet {
|
||||
fmt.Println("Note: tests require a postgres install accessible to the current user")
|
||||
}
|
||||
createDB := exec.Command("createdb", dbName)
|
||||
if !Quiet {
|
||||
createDB.Stdout = os.Stdout
|
||||
createDB.Stderr = os.Stderr
|
||||
}
|
||||
err := createDB.Run()
|
||||
if err != nil && !Quiet {
|
||||
fmt.Println("createLocalDB returned error:", err)
|
||||
}
|
||||
return dbName
|
||||
}
|
||||
|
||||
func currentUser() string {
|
||||
user, err := user.Current()
|
||||
if err != nil {
|
||||
if !Quiet {
|
||||
fmt.Println("cannot get current user: ", err)
|
||||
}
|
||||
os.Exit(2)
|
||||
}
|
||||
return user.Username
|
||||
}
|
||||
|
||||
// Prepare a sqlite or postgres connection string for testing.
|
||||
// Returns the connection string to use and a close function which must be called when the test finishes.
|
||||
// Calling this function twice will return the same database, which will have data from previous tests
|
||||
// unless close() is called.
|
||||
// TODO: namespace for concurrent package tests
|
||||
func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) {
|
||||
if dbType == DBTypeSQLite {
|
||||
dbname := "dendrite_test.db"
|
||||
return fmt.Sprintf("file:%s", dbname), func() {
|
||||
err := os.Remove(dbname)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to cleanup sqlite db '%s': %s", dbname, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Required vars: user and db
|
||||
// We'll try to infer from the local env if they are missing
|
||||
user := os.Getenv("POSTGRES_USER")
|
||||
if user == "" {
|
||||
user = currentUser()
|
||||
}
|
||||
dbName := os.Getenv("POSTGRES_DB")
|
||||
if dbName == "" {
|
||||
dbName = createLocalDB("dendrite_test")
|
||||
}
|
||||
connStr = fmt.Sprintf(
|
||||
"user=%s dbname=%s sslmode=disable",
|
||||
user, dbName,
|
||||
)
|
||||
// optional vars, used in CI
|
||||
password := os.Getenv("POSTGRES_PASSWORD")
|
||||
if password != "" {
|
||||
connStr += fmt.Sprintf(" password=%s", password)
|
||||
}
|
||||
host := os.Getenv("POSTGRES_HOST")
|
||||
if host != "" {
|
||||
connStr += fmt.Sprintf(" host=%s", host)
|
||||
}
|
||||
|
||||
return connStr, func() {
|
||||
// Drop all tables on the database to get a fresh instance
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to postgres db '%s': %s", connStr, err)
|
||||
}
|
||||
_, err = db.Exec(`DROP SCHEMA public CASCADE;
|
||||
CREATE SCHEMA public;`)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to cleanup postgres db '%s': %s", connStr, err)
|
||||
}
|
||||
_ = db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Creates subtests with each known DBType
|
||||
func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) {
|
||||
dbs := map[string]DBType{
|
||||
"postgres": DBTypePostgres,
|
||||
"sqlite": DBTypeSQLite,
|
||||
}
|
||||
for dbName, dbType := range dbs {
|
||||
dbt := dbType
|
||||
t.Run(dbName, func(tt *testing.T) {
|
||||
tt.Parallel()
|
||||
testFn(tt, dbt)
|
||||
})
|
||||
}
|
||||
}
|
||||
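Taken together with `WithAllDatabases`, a storage test in this repository can now cover both engines in a handful of lines. A sketch of the intended usage (the body is a placeholder):

	func TestExample(t *testing.T) {
		test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
			connStr, close := test.PrepareDBConnectionString(t, dbType)
			defer close() // drops the postgres schema / removes the sqlite file
			_ = connStr   // open the storage layer under test with connStr
		})
	}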
90
test/event.go
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type eventMods struct {
|
||||
originServerTS time.Time
|
||||
origin gomatrixserverlib.ServerName
|
||||
stateKey *string
|
||||
unsigned interface{}
|
||||
keyID gomatrixserverlib.KeyID
|
||||
privKey ed25519.PrivateKey
|
||||
}
|
||||
|
||||
type eventModifier func(e *eventMods)
|
||||
|
||||
func WithTimestamp(ts time.Time) eventModifier {
|
||||
return func(e *eventMods) {
|
||||
e.originServerTS = ts
|
||||
}
|
||||
}
|
||||
|
||||
func WithStateKey(skey string) eventModifier {
|
||||
return func(e *eventMods) {
|
||||
e.stateKey = &skey
|
||||
}
|
||||
}
|
||||
|
||||
func WithUnsigned(unsigned interface{}) eventModifier {
|
||||
return func(e *eventMods) {
|
||||
e.unsigned = unsigned
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse a list of events
|
||||
func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
|
||||
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
|
||||
for i := 0; i < len(in); i++ {
|
||||
out[i] = in[len(in)-i-1]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*gomatrixserverlib.HeaderedEvent) {
|
||||
t.Helper()
|
||||
if len(gotEventIDs) != len(wants) {
|
||||
t.Fatalf("length mismatch: got %d events, want %d", len(gotEventIDs), len(wants))
|
||||
}
|
||||
for i := range wants {
|
||||
w := wants[i].EventID()
|
||||
g := gotEventIDs[i]
|
||||
if w != g {
|
||||
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AssertEventsEqual(t *testing.T, gots, wants []*gomatrixserverlib.HeaderedEvent) {
|
||||
t.Helper()
|
||||
if len(gots) != len(wants) {
|
||||
t.Fatalf("length mismatch: got %d events, want %d", len(gots), len(wants))
|
||||
}
|
||||
for i := range wants {
|
||||
w := wants[i].JSON()
|
||||
g := gots[i].JSON()
|
||||
if !bytes.Equal(w, g) {
|
||||
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
|
||||
}
|
||||
}
|
||||
}
|
||||
223
test/room.go
Normal file
|
|
@ -0,0 +1,223 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type Preset int
|
||||
|
||||
var (
|
||||
PresetNone Preset = 0
|
||||
PresetPrivateChat Preset = 1
|
||||
PresetPublicChat Preset = 2
|
||||
PresetTrustedPrivateChat Preset = 3
|
||||
|
||||
roomIDCounter = int64(0)
|
||||
|
||||
testKeyID = gomatrixserverlib.KeyID("ed25519:test")
|
||||
testPrivateKey = ed25519.NewKeyFromSeed([]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
|
||||
})
|
||||
)
|
||||
|
||||
type Room struct {
|
||||
ID string
|
||||
Version gomatrixserverlib.RoomVersion
|
||||
preset Preset
|
||||
creator *User
|
||||
|
||||
authEvents gomatrixserverlib.AuthEvents
|
||||
events []*gomatrixserverlib.HeaderedEvent
|
||||
}
|
||||
|
||||
// Create a new test room. Automatically creates the initial create events.
|
||||
func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
|
||||
t.Helper()
|
||||
counter := atomic.AddInt64(&roomIDCounter, 1)
|
||||
|
||||
// set defaults then let roomModifiers override
|
||||
r := &Room{
|
||||
ID: fmt.Sprintf("!%d:localhost", counter),
|
||||
creator: creator,
|
||||
authEvents: gomatrixserverlib.NewAuthEvents(nil),
|
||||
preset: PresetPublicChat,
|
||||
Version: gomatrixserverlib.RoomVersionV9,
|
||||
}
|
||||
for _, m := range modifiers {
|
||||
m(t, r)
|
||||
}
|
||||
r.insertCreateEvents(t)
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Room) insertCreateEvents(t *testing.T) {
|
||||
t.Helper()
|
||||
var joinRule gomatrixserverlib.JoinRuleContent
|
||||
var hisVis gomatrixserverlib.HistoryVisibilityContent
|
||||
plContent := eventutil.InitialPowerLevelsContent(r.creator.ID)
|
||||
switch r.preset {
|
||||
case PresetTrustedPrivateChat:
|
||||
fallthrough
|
||||
case PresetPrivateChat:
|
||||
joinRule.JoinRule = "invite"
|
||||
hisVis.HistoryVisibility = "shared"
|
||||
case PresetPublicChat:
|
||||
joinRule.JoinRule = "public"
|
||||
hisVis.HistoryVisibility = "shared"
|
||||
}
|
||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
|
||||
"creator": r.creator.ID,
|
||||
"room_version": r.Version,
|
||||
}, WithStateKey(""))
|
||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, WithStateKey(r.creator.ID))
|
||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
|
||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
|
||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
|
||||
}
|
||||
|
||||
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
|
||||
func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
|
||||
t.Helper()
|
||||
depth := 1 + len(r.events) // depth starts at 1
|
||||
|
||||
// possible event modifiers (optional fields)
|
||||
mod := &eventMods{}
|
||||
for _, m := range mods {
|
||||
m(mod)
|
||||
}
|
||||
|
||||
if mod.privKey == nil {
|
||||
mod.privKey = testPrivateKey
|
||||
}
|
||||
if mod.keyID == "" {
|
||||
mod.keyID = testKeyID
|
||||
}
|
||||
if mod.originServerTS.IsZero() {
|
||||
mod.originServerTS = time.Now()
|
||||
}
|
||||
if mod.origin == "" {
|
||||
mod.origin = gomatrixserverlib.ServerName("localhost")
|
||||
}
|
||||
|
||||
var unsigned gomatrixserverlib.RawJSON
|
||||
var err error
|
||||
if mod.unsigned != nil {
|
||||
unsigned, err = json.Marshal(mod.unsigned)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateEvent[%s]: failed to marshal unsigned field: %s", eventType, err)
|
||||
}
|
||||
}
|
||||
|
||||
builder := &gomatrixserverlib.EventBuilder{
|
||||
Sender: creator.ID,
|
||||
RoomID: r.ID,
|
||||
Type: eventType,
|
||||
StateKey: mod.stateKey,
|
||||
Depth: int64(depth),
|
||||
Unsigned: unsigned,
|
||||
}
|
||||
err = builder.SetContent(content)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateEvent[%s]: failed to SetContent: %s", eventType, err)
|
||||
}
|
||||
if depth > 1 {
|
||||
builder.PrevEvents = []gomatrixserverlib.EventReference{r.events[len(r.events)-1].EventReference()}
|
||||
}
|
||||
|
||||
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateEvent[%s]: failed to StateNeededForEventBuilder: %s", eventType, err)
|
||||
}
|
||||
refs, err := eventsNeeded.AuthEventReferences(&r.authEvents)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateEvent[%s]: failed to AuthEventReferences: %s", eventType, err)
|
||||
}
|
||||
builder.AuthEvents = refs
|
||||
ev, err := builder.Build(
|
||||
mod.originServerTS, mod.origin, mod.keyID,
|
||||
mod.privKey, r.Version,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err)
|
||||
}
|
||||
if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil {
|
||||
t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err)
|
||||
}
|
||||
return ev.Headered(r.Version)
|
||||
}
|
||||
|
||||
// Add a new event to this room DAG. Not thread-safe.
|
||||
func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) {
|
||||
t.Helper()
|
||||
// Add the event to the list of auth events
|
||||
r.events = append(r.events, he)
|
||||
if he.StateKey() != nil {
|
||||
err := r.authEvents.AddEvent(he.Unwrap())
|
||||
if err != nil {
|
||||
t.Fatalf("InsertEvent: failed to add event to auth events: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Room) Events() []*gomatrixserverlib.HeaderedEvent {
|
||||
return r.events
|
||||
}
|
||||
|
||||
func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
|
||||
t.Helper()
|
||||
he := r.CreateEvent(t, creator, eventType, content, mods...)
|
||||
r.InsertEvent(t, he)
|
||||
return he
|
||||
}
|
||||
|
||||
// All room modifiers are below
|
||||
|
||||
type roomModifier func(t *testing.T, r *Room)
|
||||
|
||||
func RoomPreset(p Preset) roomModifier {
|
||||
return func(t *testing.T, r *Room) {
|
||||
switch p {
|
||||
case PresetPrivateChat:
|
||||
fallthrough
|
||||
case PresetPublicChat:
|
||||
fallthrough
|
||||
case PresetTrustedPrivateChat:
|
||||
fallthrough
|
||||
case PresetNone:
|
||||
r.preset = p
|
||||
default:
|
||||
t.Errorf("invalid RoomPreset: %v", p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
|
||||
return func(t *testing.T, r *Room) {
|
||||
r.Version = ver
|
||||
}
|
||||
}
|
||||
36
test/user.go
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var (
|
||||
userIDCounter = int64(0)
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID string
|
||||
}
|
||||
|
||||
func NewUser() *User {
|
||||
counter := atomic.AddInt64(&userIDCounter, 1)
|
||||
u := &User{
|
||||
ID: fmt.Sprintf("@%d:localhost", counter),
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
|
@ -404,8 +404,24 @@ func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
// Get account data to check if event.Sender() is ignored by mem.Localpart
|
||||
data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, "", "m.ignored_user_list")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if data != nil {
|
||||
ignored := types.IgnoredUsers{}
|
||||
err = json.Unmarshal(data, &ignored)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sender := event.Sender()
|
||||
if _, ok := ignored.List[sender]; ok {
|
||||
return nil, fmt.Errorf("user %s is ignored", sender)
|
||||
}
|
||||
}
|
||||
var res api.QueryPushRulesResponse
|
||||
if err := s.userAPI.QueryPushRules(ctx, &api.QueryPushRulesRequest{UserID: mem.UserID}, &res); err != nil {
|
||||
if err = s.userAPI.QueryPushRules(ctx, &api.QueryPushRulesRequest{UserID: mem.UserID}, &res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -369,6 +369,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
|||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(QueryAccountAvailabilityPath,
|
||||
httputil.MakeInternalAPI("queryAccountAvailability", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryAccountAvailabilityRequest{}
|
||||
response := api.QueryAccountAvailabilityResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := s.QueryAccountAvailability(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(QueryAccountByPasswordPath,
|
||||
httputil.MakeInternalAPI("queryAccountByPassword", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryAccountByPasswordRequest{}
|
||||
|
|
|
|||