diff --git a/appservice/api/query.go b/appservice/api/query.go index 4a4e31a99..d36e138cb 100644 --- a/appservice/api/query.go +++ b/appservice/api/query.go @@ -84,10 +84,10 @@ type AppServiceQueryAPI interface { } // AppServiceRoomAliasExistsPath is the HTTP path for the RoomAliasExists API -const AppServiceRoomAliasExistsPath = "/api/appservice/RoomAliasExists" +const AppServiceRoomAliasExistsPath = "/appservice/RoomAliasExists" // AppServiceUserIDExistsPath is the HTTP path for the UserIDExists API -const AppServiceUserIDExistsPath = "/api/appservice/UserIDExists" +const AppServiceUserIDExistsPath = "/appservice/UserIDExists" // httpAppServiceQueryAPI contains the URL to an appservice query API and a // reference to a httpClient used to reach it diff --git a/appservice/appservice.go b/appservice/appservice.go index 3c67ce9c1..68cf52e79 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -16,6 +16,7 @@ package appservice import ( "context" + "errors" "net/http" "sync" "time" @@ -29,6 +30,7 @@ import ( "github.com/matrix-org/dendrite/appservice/workers" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/basecomponent" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/transactions" @@ -82,9 +84,7 @@ func SetupAppServiceAPIComponent( Cfg: base.Cfg, } - if base.EnableHTTPAPIs { - appserviceQueryAPI.SetupHTTP(http.DefaultServeMux) - } + appserviceQueryAPI.SetupHTTP(base.InternalAPIMux) consumer := consumers.NewOutputRoomEventConsumer( base.Cfg, base.KafkaConsumer, accountsDB, appserviceDB, @@ -101,7 +101,7 @@ func SetupAppServiceAPIComponent( // Set up HTTP Endpoints routing.Setup( - base.APIMux, base.Cfg, rsAPI, + base.PublicAPIMux, base.Cfg, rsAPI, accountsDB, federation, transactionsCache, ) @@ -119,12 +119,12 @@ func generateAppServiceAccount( ctx := context.Background() // Create an account for the application service - acc, err := accountsDB.CreateAccount(ctx, as.SenderLocalpart, "", as.ID) + _, err := accountsDB.CreateAccount(ctx, as.SenderLocalpart, "", as.ID) if err != nil { + if errors.Is(err, internal.ErrUserExists) { // This account already exists + return nil + } return err - } else if acc == nil { - // This account already exists - return nil } // Create a dummy device with a dummy token for the application service diff --git a/appservice/query/query.go b/appservice/query/query.go index a61997b42..812ca9f49 100644 --- a/appservice/query/query.go +++ b/appservice/query/query.go @@ -23,6 +23,7 @@ import ( "net/url" "time" + "github.com/gorilla/mux" "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" @@ -182,8 +183,8 @@ func makeHTTPClient() *http.Client { // SetupHTTP adds the AppServiceQueryPAI handlers to the http.ServeMux. This // handles and muxes incoming api requests the to internal AppServiceQueryAPI. -func (a *AppServiceQueryAPI) SetupHTTP(servMux *http.ServeMux) { - servMux.Handle( +func (a *AppServiceQueryAPI) SetupHTTP(internalAPIMux *mux.Router) { + internalAPIMux.Handle( api.AppServiceRoomAliasExistsPath, internal.MakeInternalAPI("appserviceRoomAliasExists", func(req *http.Request) util.JSONResponse { var request api.RoomAliasExistsRequest @@ -197,7 +198,7 @@ func (a *AppServiceQueryAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.AppServiceUserIDExistsPath, internal.MakeInternalAPI("appserviceUserIDExists", func(req *http.Request) util.JSONResponse { var request api.UserIDExistsRequest diff --git a/appservice/routing/routing.go b/appservice/routing/routing.go index ac1ff32ea..e5ed7ab40 100644 --- a/appservice/routing/routing.go +++ b/appservice/routing/routing.go @@ -27,7 +27,7 @@ import ( "github.com/matrix-org/util" ) -const pathPrefixApp = "/_matrix/app/v1" +const pathPrefixApp = "/app/v1" // Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client // to clients which need to make outbound HTTP requests. diff --git a/appservice/storage/storage_wasm.go b/appservice/storage/storage_wasm.go index a6144b435..de1acf92f 100644 --- a/appservice/storage/storage_wasm.go +++ b/appservice/storage/storage_wasm.go @@ -34,7 +34,7 @@ func NewDatabase( case "postgres": return nil, fmt.Errorf("Cannot use postgres implementation") case "file": - return sqlite3.NewDatabase(dataSourceName) + return sqlite3.NewDatabase(uri.Path) default: return nil, fmt.Errorf("Cannot use postgres implementation") } diff --git a/build/docker/DendriteJS.Dockerfile b/build/docker/DendriteJS.Dockerfile new file mode 100644 index 000000000..9a84a370a --- /dev/null +++ b/build/docker/DendriteJS.Dockerfile @@ -0,0 +1,107 @@ +# This dockerfile will build dendritejs and hook it up to riot-web, build that then dump the +# resulting HTML/JS onto an nginx container for hosting. It requires no specific build context +# as it pulls archives straight from github branches. +FROM golang:1.13.7-alpine3.11 AS gobuild + +# Download and build dendrite +WORKDIR /build +ADD https://github.com/matrix-org/dendrite/archive/master.tar.gz /build/master.tar.gz +RUN tar xvfz master.tar.gz +WORKDIR /build/dendrite-master +RUN GOOS=js GOARCH=wasm go build -o main.wasm ./cmd/dendritejs + + +FROM node:14-stretch AS jsbuild +# apparently some deps require python +RUN apt-get update && apt-get -y install python + +# Download riot-web and libp2p repos +WORKDIR /build +ADD https://github.com/matrix-org/go-http-js-libp2p/archive/master.tar.gz /build/libp2p.tar.gz +RUN tar xvfz libp2p.tar.gz +ADD https://github.com/vector-im/riot-web/archive/matthew/p2p.tar.gz /build/p2p.tar.gz +RUN tar xvfz p2p.tar.gz + +# Install deps for riot-web, symlink in libp2p repo and build that too +WORKDIR /build/riot-web-matthew-p2p +RUN yarn install +RUN ln -s /build/go-http-js-libp2p-master /build/riot-web-matthew-p2p/node_modules/go-http-js-libp2p +RUN (cd node_modules/go-http-js-libp2p && yarn install) +COPY --from=gobuild /build/dendrite-master/main.wasm ./src/vector/dendrite.wasm +# build it all +RUN yarn build:p2p + +SHELL ["/bin/bash", "-c"] +RUN echo $'\ +{ \n\ + "default_server_config": { \n\ + "m.homeserver": { \n\ + "base_url": "https://p2p.riot.im", \n\ + "server_name": "p2p.riot.im" \n\ + }, \n\ + "m.identity_server": { \n\ + "base_url": "https://vector.im" \n\ + } \n\ + }, \n\ + "disable_custom_urls": false, \n\ + "disable_guests": true, \n\ + "disable_login_language_selector": false, \n\ + "disable_3pid_login": true, \n\ + "brand": "Riot", \n\ + "integrations_ui_url": "https://scalar.vector.im/", \n\ + "integrations_rest_url": "https://scalar.vector.im/api", \n\ + "integrations_widgets_urls": [ \n\ + "https://scalar.vector.im/_matrix/integrations/v1", \n\ + "https://scalar.vector.im/api", \n\ + "https://scalar-staging.vector.im/_matrix/integrations/v1", \n\ + "https://scalar-staging.vector.im/api", \n\ + "https://scalar-staging.riot.im/scalar/api" \n\ + ], \n\ + "integrations_jitsi_widget_url": "https://scalar.vector.im/api/widgets/jitsi.html", \n\ + "bug_report_endpoint_url": "https://riot.im/bugreports/submit", \n\ + "defaultCountryCode": "GB", \n\ + "showLabsSettings": false, \n\ + "features": { \n\ + "feature_pinning": "labs", \n\ + "feature_custom_status": "labs", \n\ + "feature_custom_tags": "labs", \n\ + "feature_state_counters": "labs" \n\ + }, \n\ + "default_federate": true, \n\ + "default_theme": "light", \n\ + "roomDirectory": { \n\ + "servers": [ \n\ + "matrix.org" \n\ + ] \n\ + }, \n\ + "welcomeUserId": "", \n\ + "piwik": { \n\ + "url": "https://piwik.riot.im/", \n\ + "whitelistedHSUrls": ["https://matrix.org"], \n\ + "whitelistedISUrls": ["https://vector.im", "https://matrix.org"], \n\ + "siteId": 1 \n\ + }, \n\ + "enable_presence_by_hs_url": { \n\ + "https://matrix.org": false, \n\ + "https://matrix-client.matrix.org": false \n\ + }, \n\ + "settingDefaults": { \n\ + "breadcrumbs": true \n\ + } \n\ +}' > webapp/config.json + +FROM nginx +# Add "Service-Worker-Allowed: /" header so the worker can sniff traffic on this domain rather +# than just the path this gets hosted under. NB this newline echo syntax only works on bash. +SHELL ["/bin/bash", "-c"] +RUN echo $'\ +server { \n\ + listen 80; \n\ + add_header \'Service-Worker-Allowed\' \'/\'; \n\ + location / { \n\ + root /usr/share/nginx/html; \n\ + index index.html index.htm; \n\ + } \n\ +}' > /etc/nginx/conf.d/default.conf +RUN sed -i 's/}/ application\/wasm wasm;\n}/g' /etc/nginx/mime.types +COPY --from=jsbuild /build/riot-web-matthew-p2p/webapp /usr/share/nginx/html diff --git a/build/docker/hub/config/dendrite-config.yaml b/build/docker/hub/config/dendrite-config.yaml index 23d6479b5..26dc272ab 100644 --- a/build/docker/hub/config/dendrite-config.yaml +++ b/build/docker/hub/config/dendrite-config.yaml @@ -110,11 +110,13 @@ listen: room_server: "room_server:7770" client_api: "client_api:7771" federation_api: "federation_api:7772" + server_key_api: "server_key_api:7778" sync_api: "sync_api:7773" media_api: "media_api:7774" public_rooms_api: "public_rooms_api:7775" federation_sender: "federation_sender:7776" edu_server: "edu_server:7777" + key_server: "key_server:7779" # The configuration for tracing the dendrite components. tracing: diff --git a/build/docker/hub/docker-compose.polylith.yml b/build/docker/hub/docker-compose.polylith.yml index d88e5bfb3..178604093 100644 --- a/build/docker/hub/docker-compose.polylith.yml +++ b/build/docker/hub/docker-compose.polylith.yml @@ -131,7 +131,7 @@ services: - internal key_server: - hostname: key_serverde + hostname: key_server image: matrixdotorg/dendrite:keyserver command: [ "--config=dendrite.yaml" @@ -141,6 +141,17 @@ services: networks: - internal + server_key_api: + hostname: server_key_api + image: matrixdotorg/dendrite:serverkeyapi + command: [ + "--config=dendrite.yaml" + ] + volumes: + - ./config:/etc/dendrite + networks: + - internal + networks: internal: attachable: true diff --git a/build/docker/hub/images-build.sh b/build/docker/hub/images-build.sh index 0c6a0eb7a..b162a0639 100755 --- a/build/docker/hub/images-build.sh +++ b/build/docker/hub/images-build.sh @@ -2,16 +2,17 @@ cd $(git rev-parse --show-toplevel) -docker build -f docker/hub/Dockerfile -t matrixdotorg/dendrite:latest . +docker build -f build/docker/hub/Dockerfile -t matrixdotorg/dendrite:latest . -docker build -t matrixdotorg/dendrite:clientapi --build-arg component=dendrite-client-api-server -f docker/hub/Dockerfile.component . -docker build -t matrixdotorg/dendrite:clientproxy --build-arg component=client-api-proxy -f docker/hub/Dockerfile.component . -docker build -t matrixdotorg/dendrite:eduserver --build-arg component=dendrite-edu-server -f docker/hub/Dockerfile.component . -docker build -t matrixdotorg/dendrite:federationapi --build-arg component=dendrite-federation-api-server -f docker/hub/Dockerfile.component . -docker build -t matrixdotorg/dendrite:federationsender --build-arg component=dendrite-federation-sender-server -f docker/hub/Dockerfile.component . -docker build -t matrixdotorg/dendrite:federationproxy --build-arg component=federation-api-proxy -f docker/hub/Dockerfile.component . -docker build -t matrixdotorg/dendrite:keyserver --build-arg component=dendrite-key-server -f docker/hub/Dockerfile.component . -docker build -t matrixdotorg/dendrite:mediaapi --build-arg component=dendrite-media-api-server -f docker/hub/Dockerfile.component . -docker build -t matrixdotorg/dendrite:publicroomsapi --build-arg component=dendrite-public-rooms-api-server -f docker/hub/Dockerfile.component . -docker build -t matrixdotorg/dendrite:roomserver --build-arg component=dendrite-room-server -f docker/hub/Dockerfile.component . -docker build -t matrixdotorg/dendrite:syncapi --build-arg component=dendrite-sync-api-server -f docker/hub/Dockerfile.component . +docker build -t matrixdotorg/dendrite:clientapi --build-arg component=dendrite-client-api-server -f build/docker/hub/Dockerfile.component . +docker build -t matrixdotorg/dendrite:clientproxy --build-arg component=client-api-proxy -f build/docker/hub/Dockerfile.component . +docker build -t matrixdotorg/dendrite:eduserver --build-arg component=dendrite-edu-server -f build/docker/hub/Dockerfile.component . +docker build -t matrixdotorg/dendrite:federationapi --build-arg component=dendrite-federation-api-server -f build/docker/hub/Dockerfile.component . +docker build -t matrixdotorg/dendrite:federationsender --build-arg component=dendrite-federation-sender-server -f build/docker/hub/Dockerfile.component . +docker build -t matrixdotorg/dendrite:federationproxy --build-arg component=federation-api-proxy -f build/docker/hub/Dockerfile.component . +docker build -t matrixdotorg/dendrite:keyserver --build-arg component=dendrite-key-server -f build/docker/hub/Dockerfile.component . +docker build -t matrixdotorg/dendrite:mediaapi --build-arg component=dendrite-media-api-server -f build/docker/hub/Dockerfile.component . +docker build -t matrixdotorg/dendrite:publicroomsapi --build-arg component=dendrite-public-rooms-api-server -f build/docker/hub/Dockerfile.component . +docker build -t matrixdotorg/dendrite:roomserver --build-arg component=dendrite-room-server -f build/docker/hub/Dockerfile.component . +docker build -t matrixdotorg/dendrite:syncapi --build-arg component=dendrite-sync-api-server -f build/docker/hub/Dockerfile.component . +docker build -t matrixdotorg/dendrite:serverkeyapi --build-arg component=dendrite-server-key-api-server -f build/docker/hub/Dockerfile.component . diff --git a/build/scripts/build-test-lint.sh b/build/scripts/build-test-lint.sh index d2b2b4b16..4b18ca2f8 100755 --- a/build/scripts/build-test-lint.sh +++ b/build/scripts/build-test-lint.sh @@ -13,4 +13,4 @@ go build ./cmd/... ./scripts/find-lint.sh echo "Testing..." -go test ./... +go test -v ./... diff --git a/clientapi/auth/storage/accounts/interface.go b/clientapi/auth/storage/accounts/interface.go index a860f809d..4d1941a23 100644 --- a/clientapi/auth/storage/accounts/interface.go +++ b/clientapi/auth/storage/accounts/interface.go @@ -29,6 +29,9 @@ type Database interface { GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error + // CreateAccount makes a new account with the given login name and password, and creates an empty profile + // for this account. If no password is supplied, the account will be a passwordless account. If the + // account already exists, it will return nil, ErrUserExists. CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error) CreateGuestAccount(ctx context.Context) (*authtypes.Account, error) UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error diff --git a/clientapi/auth/storage/accounts/postgres/storage.go b/clientapi/auth/storage/accounts/postgres/storage.go index 4a1832674..4be5dca94 100644 --- a/clientapi/auth/storage/accounts/postgres/storage.go +++ b/clientapi/auth/storage/accounts/postgres/storage.go @@ -138,7 +138,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Accou // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the -// account already exists, it will return nil, nil. +// account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, ) (acc *authtypes.Account, err error) { @@ -164,7 +164,7 @@ func (d *Database) createAccount( } if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { if internal.IsUniqueConstraintViolationErr(err) { - return nil, nil + return nil, internal.ErrUserExists } return nil, err } diff --git a/roomserver/storage/sqlite3/prepare.go b/clientapi/auth/storage/accounts/sqlite3/constraint.go similarity index 50% rename from roomserver/storage/sqlite3/prepare.go rename to clientapi/auth/storage/accounts/sqlite3/constraint.go index 482dfa2b9..32f96c8e4 100644 --- a/roomserver/storage/sqlite3/prepare.go +++ b/clientapi/auth/storage/accounts/sqlite3/constraint.go @@ -1,5 +1,4 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,24 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build !wasm + package sqlite3 import ( - "database/sql" + "errors" + + "github.com/mattn/go-sqlite3" ) -// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. -type statementList []struct { - statement **sql.Stmt - sql string -} - -// prepare the SQL for each statement in the list and assign the result to the prepared statement. -func (s statementList) prepare(db *sql.DB) (err error) { - for _, statement := range s { - if *statement.statement, err = db.Prepare(statement.sql); err != nil { - return - } - } - return +func isConstraintError(err error) bool { + return errors.Is(err, sqlite3.ErrConstraint) } diff --git a/clientapi/auth/storage/accounts/sqlite3/constraint_wasm.go b/clientapi/auth/storage/accounts/sqlite3/constraint_wasm.go new file mode 100644 index 000000000..b5077c070 --- /dev/null +++ b/clientapi/auth/storage/accounts/sqlite3/constraint_wasm.go @@ -0,0 +1,21 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build wasm + +package sqlite3 + +func isConstraintError(err error) bool { + return false +} \ No newline at end of file diff --git a/clientapi/auth/storage/accounts/sqlite3/storage.go b/clientapi/auth/storage/accounts/sqlite3/storage.go index 7dec87294..5233001f2 100644 --- a/clientapi/auth/storage/accounts/sqlite3/storage.go +++ b/clientapi/auth/storage/accounts/sqlite3/storage.go @@ -26,9 +26,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" - - // Import the postgres database driver. - _ "github.com/mattn/go-sqlite3" + // Import the sqlite3 database driver. ) // Database represents an account database @@ -148,7 +146,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Accou // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the -// account already exists, it will return nil, nil. +// account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, ) (acc *authtypes.Account, err error) { @@ -172,8 +170,8 @@ func (d *Database) createAccount( } } if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { - if internal.IsUniqueConstraintViolationErr(err) { - return nil, nil + if isConstraintError(err) { + return nil, internal.ErrUserExists } return nil, err } diff --git a/clientapi/auth/storage/accounts/storage_wasm.go b/clientapi/auth/storage/accounts/storage_wasm.go index 81e77cf79..7cf50de79 100644 --- a/clientapi/auth/storage/accounts/storage_wasm.go +++ b/clientapi/auth/storage/accounts/storage_wasm.go @@ -36,7 +36,7 @@ func NewDatabase( case "postgres": return nil, fmt.Errorf("Cannot use postgres implementation") case "file": - return sqlite3.NewDatabase(dataSourceName, serverName) + return sqlite3.NewDatabase(uri.Path, serverName) default: return nil, fmt.Errorf("Cannot use postgres implementation") } diff --git a/clientapi/auth/storage/devices/postgres/devices_table.go b/clientapi/auth/storage/devices/postgres/devices_table.go index 67d573f95..c84e83d3d 100644 --- a/clientapi/auth/storage/devices/postgres/devices_table.go +++ b/clientapi/auth/storage/devices/postgres/devices_table.go @@ -206,9 +206,8 @@ func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*authtypes.Device, error) { var dev authtypes.Device - var created sql.NullInt64 stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&created) + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) @@ -230,10 +229,17 @@ func (s *devicesStatements) selectDevicesByLocalpart( for rows.Next() { var dev authtypes.Device - err = rows.Scan(&dev.ID, &dev.DisplayName) + var id, displayname sql.NullString + err = rows.Scan(&id, &displayname) if err != nil { return devices, err } + if id.Valid { + dev.ID = id.String + } + if displayname.Valid { + dev.DisplayName = displayname.String + } dev.UserID = userutil.MakeUserID(localpart, s.serverName) devices = append(devices, dev) } diff --git a/clientapi/auth/storage/devices/sqlite3/devices_table.go b/clientapi/auth/storage/devices/sqlite3/devices_table.go index 6ffc1646f..49f3eaed7 100644 --- a/clientapi/auth/storage/devices/sqlite3/devices_table.go +++ b/clientapi/auth/storage/devices/sqlite3/devices_table.go @@ -208,9 +208,8 @@ func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*authtypes.Device, error) { var dev authtypes.Device - var created sql.NullInt64 stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&created) + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) @@ -231,10 +230,17 @@ func (s *devicesStatements) selectDevicesByLocalpart( for rows.Next() { var dev authtypes.Device - err = rows.Scan(&dev.ID, &dev.DisplayName) + var id, displayname sql.NullString + err = rows.Scan(&id, &displayname) if err != nil { return devices, err } + if id.Valid { + dev.ID = id.String + } + if displayname.Valid { + dev.DisplayName = displayname.String + } dev.UserID = userutil.MakeUserID(localpart, s.serverName) devices = append(devices, dev) } diff --git a/clientapi/auth/storage/devices/storage_wasm.go b/clientapi/auth/storage/devices/storage_wasm.go index 14c19e74b..c89ad887b 100644 --- a/clientapi/auth/storage/devices/storage_wasm.go +++ b/clientapi/auth/storage/devices/storage_wasm.go @@ -36,7 +36,7 @@ func NewDatabase( case "postgres": return nil, fmt.Errorf("Cannot use postgres implementation") case "file": - return sqlite3.NewDatabase(dataSourceName, serverName) + return sqlite3.NewDatabase(uri.Path, serverName) default: return nil, fmt.Errorf("Cannot use postgres implementation") } diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 66317e92e..0ce44c211 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -65,7 +65,7 @@ func SetupClientAPIComponent( } routing.Setup( - base.APIMux, base.Cfg, roomserverProducer, rsAPI, asAPI, + base.PublicAPIMux, base.Cfg, roomserverProducer, rsAPI, asAPI, accountsDB, deviceDB, federation, *keyRing, userUpdateProducer, syncProducer, eduProducer, transactionsCache, fsAPI, ) diff --git a/clientapi/producers/eduserver.go b/clientapi/producers/eduserver.go index 30c40fb7f..102c1fad8 100644 --- a/clientapi/producers/eduserver.go +++ b/clientapi/producers/eduserver.go @@ -14,6 +14,7 @@ package producers import ( "context" + "encoding/json" "time" "github.com/matrix-org/dendrite/eduserver/api" @@ -52,3 +53,28 @@ func (p *EDUServerProducer) SendTyping( return err } + +// SendToDevice sends a typing event to EDU server +func (p *EDUServerProducer) SendToDevice( + ctx context.Context, sender, userID, deviceID, eventType string, + message interface{}, +) error { + js, err := json.Marshal(message) + if err != nil { + return err + } + requestData := api.InputSendToDeviceEvent{ + UserID: userID, + DeviceID: deviceID, + SendToDeviceEvent: gomatrixserverlib.SendToDeviceEvent{ + Sender: sender, + Type: eventType, + Content: js, + }, + } + request := api.InputSendToDeviceEventRequest{ + InputSendToDeviceEvent: requestData, + } + response := api.InputSendToDeviceEventResponse{} + return p.InputAPI.InputSendToDeviceEvent(ctx, &request, &response) +} diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index b1dfb56fb..3d4b5f5bc 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -77,7 +77,7 @@ func DirectoryRoom( if fedErr != nil { // TODO: Return 502 if the remote server errored. // TODO: Return 504 if the remote server timed out. - util.GetLogger(req.Context()).WithError(err).Error("federation.LookupRoomAlias failed") + util.GetLogger(req.Context()).WithError(fedErr).Error("federation.LookupRoomAlias failed") return jsonerror.InternalServerError() } res.RoomID = fedRes.RoomID diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index 500df337a..a3d676532 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -75,6 +75,6 @@ func JoinRoomByIDOrAlias( // TODO: Put the response struct somewhere internal. JSON: struct { RoomID string `json:"room_id"` - }{joinReq.RoomIDOrAlias}, + }{joinRes.RoomID}, } } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 40a2862f8..d356db2cb 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -830,15 +830,16 @@ func completeRegistration( acc, err := accountDB.CreateAccount(ctx, username, password, appserviceID) if err != nil { + if errors.Is(err, internal.ErrUserExists) { // user already exists + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.UserInUse("Desired user ID is already taken."), + } + } return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: jsonerror.Unknown("failed to create account: " + err.Error()), } - } else if acc == nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.UserInUse("Desired user ID is already taken."), - } } // Increment prometheus counter for created users diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 377881cb6..83e399ac3 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -36,9 +36,9 @@ import ( "github.com/matrix-org/util" ) -const pathPrefixV1 = "/_matrix/client/api/v1" -const pathPrefixR0 = "/_matrix/client/r0" -const pathPrefixUnstable = "/_matrix/client/unstable" +const pathPrefixV1 = "/client/api/v1" +const pathPrefixR0 = "/client/r0" +const pathPrefixUnstable = "/client/unstable" // Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client // to clients which need to make outbound HTTP requests. @@ -47,7 +47,7 @@ const pathPrefixUnstable = "/_matrix/client/unstable" // applied: // nolint: gocyclo func Setup( - apiMux *mux.Router, cfg *config.Dendrite, + publicAPIMux *mux.Router, cfg *config.Dendrite, producer *producers.RoomserverProducer, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, @@ -62,7 +62,7 @@ func Setup( federationSender federationSenderAPI.FederationSenderInternalAPI, ) { - apiMux.Handle("/_matrix/client/versions", + publicAPIMux.Handle("/client/versions", internal.MakeExternalAPI("versions", func(req *http.Request) util.JSONResponse { return util.JSONResponse{ Code: http.StatusOK, @@ -78,9 +78,9 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() - v1mux := apiMux.PathPrefix(pathPrefixV1).Subrouter() - unstableMux := apiMux.PathPrefix(pathPrefixUnstable).Subrouter() + r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter() + v1mux := publicAPIMux.PathPrefix(pathPrefixV1).Subrouter() + unstableMux := publicAPIMux.PathPrefix(pathPrefixUnstable).Subrouter() authData := auth.Data{ AccountDB: accountDB, @@ -274,6 +274,31 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) + r0mux.Handle("/sendToDevice/{eventType}/{txnID}", + internal.MakeAuthAPI("send_to_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + txnID := vars["txnID"] + return SendToDevice(req, device, eduProducer, transactionsCache, vars["eventType"], &txnID) + }), + ).Methods(http.MethodPut, http.MethodOptions) + + // This is only here because sytest refers to /unstable for this endpoint + // rather than r0. It's an exact duplicate of the above handler. + // TODO: Remove this if/when sytest is fixed! + unstableMux.Handle("/sendToDevice/{eventType}/{txnID}", + internal.MakeAuthAPI("send_to_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + txnID := vars["txnID"] + return SendToDevice(req, device, eduProducer, transactionsCache, vars["eventType"], &txnID) + }), + ).Methods(http.MethodPut, http.MethodOptions) + r0mux.Handle("/account/whoami", internal.MakeAuthAPI("whoami", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { return Whoami(req, device) diff --git a/clientapi/routing/sendtodevice.go b/clientapi/routing/sendtodevice.go new file mode 100644 index 000000000..5d3060d74 --- /dev/null +++ b/clientapi/routing/sendtodevice.go @@ -0,0 +1,70 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/producers" + "github.com/matrix-org/dendrite/internal/transactions" + "github.com/matrix-org/util" +) + +// SendToDevice handles PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId} +// sends the device events to the EDU Server +func SendToDevice( + req *http.Request, device *authtypes.Device, + eduProducer *producers.EDUServerProducer, + txnCache *transactions.Cache, + eventType string, txnID *string, +) util.JSONResponse { + if txnID != nil { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + return *res + } + } + + var httpReq struct { + Messages map[string]map[string]json.RawMessage `json:"messages"` + } + resErr := httputil.UnmarshalJSONRequest(req, &httpReq) + if resErr != nil { + return *resErr + } + + for userID, byUser := range httpReq.Messages { + for deviceID, message := range byUser { + if err := eduProducer.SendToDevice( + req.Context(), device.UserID, userID, deviceID, eventType, message, + ); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("eduProducer.SendToDevice failed") + return jsonerror.InternalServerError() + } + } + } + + res := util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } + + if txnID != nil { + txnCache.AddTransaction(device.AccessToken, *txnID, &res) + } + + return res +} diff --git a/cmd/client-api-proxy/main.go b/cmd/client-api-proxy/main.go index 27991c109..979b0b042 100644 --- a/cmd/client-api-proxy/main.go +++ b/cmd/client-api-proxy/main.go @@ -75,7 +75,6 @@ func makeProxy(targetURL string) (*httputil.ReverseProxy, error) { // Pratically this means that any distinction between '%2F' and '/' // in the URL will be lost by the time it reaches the target. path := req.URL.Path - path = "api" + path log.WithFields(log.Fields{ "path": path, "url": targetURL, diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index eca9b2fe6..9c1e45932 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -69,13 +69,10 @@ func main() { os.Exit(1) } - account, err := accountDB.CreateAccount(context.Background(), *username, *password, "") + _, err = accountDB.CreateAccount(context.Background(), *username, *password, "") if err != nil { fmt.Println(err.Error()) os.Exit(1) - } else if account == nil { - fmt.Println("Username already exists") - os.Exit(1) } deviceDB, err := devices.NewDatabase(*database, nil, serverName) diff --git a/cmd/dendrite-client-api-server/main.go b/cmd/dendrite-client-api-server/main.go index 0815be799..f919243de 100644 --- a/cmd/dendrite-client-api-server/main.go +++ b/cmd/dendrite-client-api-server/main.go @@ -19,7 +19,6 @@ import ( "github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/internal/basecomponent" - "github.com/matrix-org/dendrite/internal/keydb" "github.com/matrix-org/dendrite/internal/transactions" ) @@ -31,18 +30,19 @@ func main() { accountDB := base.CreateAccountsDB() deviceDB := base.CreateDeviceDB() - keyDB := base.CreateKeyDB() federation := base.CreateFederationClient() - keyRing := keydb.CreateKeyRing(federation.Client, keyDB, cfg.Matrix.KeyPerspectives) + + serverKeyAPI := base.CreateHTTPServerKeyAPIs() + keyRing := serverKeyAPI.KeyRing() asQuery := base.CreateHTTPAppServiceAPIs() rsAPI := base.CreateHTTPRoomserverAPIs() fsAPI := base.CreateHTTPFederationSenderAPIs() rsAPI.SetFederationSenderAPI(fsAPI) - eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New()) + eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB) clientapi.SetupClientAPIComponent( - base, deviceDB, accountDB, federation, &keyRing, + base, deviceDB, accountDB, federation, keyRing, rsAPI, eduInputAPI, asQuery, transactions.New(), fsAPI, ) diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index a04a28e06..c00316ec7 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -37,33 +37,23 @@ import ( "github.com/matrix-org/dendrite/federationsender" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" - "github.com/matrix-org/dendrite/internal/keydb" "github.com/matrix-org/dendrite/internal/transactions" "github.com/matrix-org/dendrite/mediaapi" "github.com/matrix-org/dendrite/publicroomsapi" "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/serverkeyapi" "github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/eduserver/cache" - "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sirupsen/logrus" ) func createKeyDB( base *P2PDendrite, -) keydb.Database { - db, err := keydb.NewDatabase( - string(base.Base.Cfg.Database.ServerKey), - base.Base.Cfg.DbProperties(), - base.Base.Cfg.Matrix.ServerName, - base.Base.Cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey), - base.Base.Cfg.Matrix.KeyID, - ) - if err != nil { - logrus.WithError(err).Panicf("failed to connect to keys db") - } + db gomatrixserverlib.KeyDatabase, +) { mdns := mDNSListener{ host: base.LibP2P, keydb: db, @@ -78,7 +68,6 @@ func createKeyDB( panic(err) } serv.RegisterNotifee(&mdns) - return db } func createFederationClient( @@ -145,45 +134,52 @@ func main() { accountDB := base.Base.CreateAccountsDB() deviceDB := base.Base.CreateDeviceDB() - keyDB := createKeyDB(base) federation := createFederationClient(base) - keyRing := keydb.CreateKeyRing(federation.Client, keyDB, cfg.Matrix.KeyPerspectives) + + serverKeyAPI := serverkeyapi.SetupServerKeyAPIComponent( + &base.Base, federation, + ) + keyRing := serverKeyAPI.KeyRing() + createKeyDB( + base, serverKeyAPI, + ) rsAPI := roomserver.SetupRoomServerComponent( &base.Base, keyRing, federation, ) eduInputAPI := eduserver.SetupEDUServerComponent( - &base.Base, cache.New(), + &base.Base, cache.New(), deviceDB, ) asAPI := appservice.SetupAppServiceAPIComponent( &base.Base, accountDB, deviceDB, federation, rsAPI, transactions.New(), ) fsAPI := federationsender.SetupFederationSenderComponent( - &base.Base, federation, rsAPI, &keyRing, + &base.Base, federation, rsAPI, keyRing, ) rsAPI.SetFederationSenderAPI(fsAPI) clientapi.SetupClientAPIComponent( &base.Base, deviceDB, accountDB, - federation, &keyRing, rsAPI, + federation, keyRing, rsAPI, eduInputAPI, asAPI, transactions.New(), fsAPI, ) eduProducer := producers.NewEDUServerProducer(eduInputAPI) - federationapi.SetupFederationAPIComponent(&base.Base, accountDB, deviceDB, federation, &keyRing, rsAPI, asAPI, fsAPI, eduProducer) + federationapi.SetupFederationAPIComponent(&base.Base, accountDB, deviceDB, federation, keyRing, rsAPI, asAPI, fsAPI, eduProducer) mediaapi.SetupMediaAPIComponent(&base.Base, deviceDB) - publicRoomsDB, err := storage.NewPublicRoomsServerDatabaseWithPubSub(string(base.Base.Cfg.Database.PublicRoomsAPI), base.LibP2PPubsub) + publicRoomsDB, err := storage.NewPublicRoomsServerDatabaseWithPubSub(string(base.Base.Cfg.Database.PublicRoomsAPI), base.LibP2PPubsub, cfg.Matrix.ServerName) if err != nil { logrus.WithError(err).Panicf("failed to connect to public rooms db") } publicroomsapi.SetupPublicRoomsAPIComponent(&base.Base, deviceDB, publicRoomsDB, rsAPI, federation, nil) // Check this later syncapi.SetupSyncAPIComponent(&base.Base, deviceDB, accountDB, rsAPI, federation, &cfg) - httpHandler := internal.WrapHandlerInCORS(base.Base.APIMux) - - // Set up the API endpoints we handle. /metrics is for prometheus, and is - // not wrapped by CORS, while everything else is - http.Handle("/metrics", promhttp.Handler()) - http.Handle("/", httpHandler) + internal.SetupHTTPAPI( + http.DefaultServeMux, + base.Base.PublicAPIMux, + base.Base.InternalAPIMux, + &cfg, + base.Base.EnableHTTPAPIs, + ) // Expose the matrix APIs directly rather than putting them under a /api path. go func() { diff --git a/cmd/dendrite-demo-libp2p/mdnslistener.go b/cmd/dendrite-demo-libp2p/mdnslistener.go index c05c94d59..01705cf39 100644 --- a/cmd/dendrite-demo-libp2p/mdnslistener.go +++ b/cmd/dendrite-demo-libp2p/mdnslistener.go @@ -21,12 +21,11 @@ import ( "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" - "github.com/matrix-org/dendrite/internal/keydb" "github.com/matrix-org/gomatrixserverlib" ) type mDNSListener struct { - keydb keydb.Database + keydb gomatrixserverlib.KeyDatabase host host.Host } diff --git a/cmd/dendrite-demo-libp2p/storage/postgreswithdht/storage.go b/cmd/dendrite-demo-libp2p/storage/postgreswithdht/storage.go index cd2804c97..d2cb36a8b 100644 --- a/cmd/dendrite-demo-libp2p/storage/postgreswithdht/storage.go +++ b/cmd/dendrite-demo-libp2p/storage/postgreswithdht/storage.go @@ -44,8 +44,8 @@ type PublicRoomsServerDatabase struct { } // NewPublicRoomsServerDatabase creates a new public rooms server database. -func NewPublicRoomsServerDatabase(dataSourceName string, dht *dht.IpfsDHT) (*PublicRoomsServerDatabase, error) { - pg, err := postgres.NewPublicRoomsServerDatabase(dataSourceName, nil) +func NewPublicRoomsServerDatabase(dataSourceName string, dht *dht.IpfsDHT, localServerName gomatrixserverlib.ServerName) (*PublicRoomsServerDatabase, error) { + pg, err := postgres.NewPublicRoomsServerDatabase(dataSourceName, nil, localServerName) if err != nil { return nil, err } diff --git a/cmd/dendrite-demo-libp2p/storage/postgreswithpubsub/storage.go b/cmd/dendrite-demo-libp2p/storage/postgreswithpubsub/storage.go index e4372c64f..cf642eb38 100644 --- a/cmd/dendrite-demo-libp2p/storage/postgreswithpubsub/storage.go +++ b/cmd/dendrite-demo-libp2p/storage/postgreswithpubsub/storage.go @@ -47,8 +47,8 @@ type PublicRoomsServerDatabase struct { } // NewPublicRoomsServerDatabase creates a new public rooms server database. -func NewPublicRoomsServerDatabase(dataSourceName string, pubsub *pubsub.PubSub) (*PublicRoomsServerDatabase, error) { - pg, err := postgres.NewPublicRoomsServerDatabase(dataSourceName, nil) +func NewPublicRoomsServerDatabase(dataSourceName string, pubsub *pubsub.PubSub, localServerName gomatrixserverlib.ServerName) (*PublicRoomsServerDatabase, error) { + pg, err := postgres.NewPublicRoomsServerDatabase(dataSourceName, nil, localServerName) if err != nil { return nil, err } diff --git a/cmd/dendrite-demo-libp2p/storage/storage.go b/cmd/dendrite-demo-libp2p/storage/storage.go index 668edbaa3..2d8dc1817 100644 --- a/cmd/dendrite-demo-libp2p/storage/storage.go +++ b/cmd/dendrite-demo-libp2p/storage/storage.go @@ -23,39 +23,40 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-libp2p/storage/postgreswithpubsub" "github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/matrix-org/dendrite/publicroomsapi/storage/sqlite3" + "github.com/matrix-org/gomatrixserverlib" ) const schemePostgres = "postgres" const schemeFile = "file" // NewPublicRoomsServerDatabase opens a database connection. -func NewPublicRoomsServerDatabaseWithDHT(dataSourceName string, dht *dht.IpfsDHT) (storage.Database, error) { +func NewPublicRoomsServerDatabaseWithDHT(dataSourceName string, dht *dht.IpfsDHT, localServerName gomatrixserverlib.ServerName) (storage.Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { - return postgreswithdht.NewPublicRoomsServerDatabase(dataSourceName, dht) + return postgreswithdht.NewPublicRoomsServerDatabase(dataSourceName, dht, localServerName) } switch uri.Scheme { case schemePostgres: - return postgreswithdht.NewPublicRoomsServerDatabase(dataSourceName, dht) + return postgreswithdht.NewPublicRoomsServerDatabase(dataSourceName, dht, localServerName) case schemeFile: - return sqlite3.NewPublicRoomsServerDatabase(dataSourceName) + return sqlite3.NewPublicRoomsServerDatabase(dataSourceName, localServerName) default: - return postgreswithdht.NewPublicRoomsServerDatabase(dataSourceName, dht) + return postgreswithdht.NewPublicRoomsServerDatabase(dataSourceName, dht, localServerName) } } // NewPublicRoomsServerDatabase opens a database connection. -func NewPublicRoomsServerDatabaseWithPubSub(dataSourceName string, pubsub *pubsub.PubSub) (storage.Database, error) { +func NewPublicRoomsServerDatabaseWithPubSub(dataSourceName string, pubsub *pubsub.PubSub, localServerName gomatrixserverlib.ServerName) (storage.Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { - return postgreswithpubsub.NewPublicRoomsServerDatabase(dataSourceName, pubsub) + return postgreswithpubsub.NewPublicRoomsServerDatabase(dataSourceName, pubsub, localServerName) } switch uri.Scheme { case schemePostgres: - return postgreswithpubsub.NewPublicRoomsServerDatabase(dataSourceName, pubsub) + return postgreswithpubsub.NewPublicRoomsServerDatabase(dataSourceName, pubsub, localServerName) case schemeFile: - return sqlite3.NewPublicRoomsServerDatabase(dataSourceName) + return sqlite3.NewPublicRoomsServerDatabase(dataSourceName, localServerName) default: - return postgreswithpubsub.NewPublicRoomsServerDatabase(dataSourceName, pubsub) + return postgreswithpubsub.NewPublicRoomsServerDatabase(dataSourceName, pubsub, localServerName) } } diff --git a/cmd/dendrite-edu-server/main.go b/cmd/dendrite-edu-server/main.go index 66e17e577..ca0460f8d 100644 --- a/cmd/dendrite-edu-server/main.go +++ b/cmd/dendrite-edu-server/main.go @@ -29,8 +29,9 @@ func main() { logrus.WithError(err).Warn("BaseDendrite close failed") } }() + deviceDB := base.CreateDeviceDB() - eduserver.SetupEDUServerComponent(base, cache.New()) + eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB) base.SetupAndServeHTTP(string(base.Cfg.Bind.EDUServer), string(base.Cfg.Listen.EDUServer)) diff --git a/cmd/dendrite-federation-api-server/main.go b/cmd/dendrite-federation-api-server/main.go index be97bc9a0..af63b549a 100644 --- a/cmd/dendrite-federation-api-server/main.go +++ b/cmd/dendrite-federation-api-server/main.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/internal/basecomponent" - "github.com/matrix-org/dendrite/internal/keydb" ) func main() { @@ -30,19 +29,21 @@ func main() { accountDB := base.CreateAccountsDB() deviceDB := base.CreateDeviceDB() - keyDB := base.CreateKeyDB() federation := base.CreateFederationClient() + + serverKeyAPI := base.CreateHTTPServerKeyAPIs() + keyRing := serverKeyAPI.KeyRing() + fsAPI := base.CreateHTTPFederationSenderAPIs() - keyRing := keydb.CreateKeyRing(federation.Client, keyDB, cfg.Matrix.KeyPerspectives) rsAPI := base.CreateHTTPRoomserverAPIs() asAPI := base.CreateHTTPAppServiceAPIs() rsAPI.SetFederationSenderAPI(fsAPI) - eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New()) + eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB) eduProducer := producers.NewEDUServerProducer(eduInputAPI) federationapi.SetupFederationAPIComponent( - base, accountDB, deviceDB, federation, &keyRing, + base, accountDB, deviceDB, federation, keyRing, rsAPI, asAPI, fsAPI, eduProducer, ) diff --git a/cmd/dendrite-federation-sender-server/main.go b/cmd/dendrite-federation-sender-server/main.go index b6af8deb7..2b1e0ae8f 100644 --- a/cmd/dendrite-federation-sender-server/main.go +++ b/cmd/dendrite-federation-sender-server/main.go @@ -17,7 +17,6 @@ package main import ( "github.com/matrix-org/dendrite/federationsender" "github.com/matrix-org/dendrite/internal/basecomponent" - "github.com/matrix-org/dendrite/internal/keydb" ) func main() { @@ -26,11 +25,13 @@ func main() { defer base.Close() // nolint: errcheck federation := base.CreateFederationClient() - keyDB := base.CreateKeyDB() - keyRing := keydb.CreateKeyRing(federation.Client, keyDB, cfg.Matrix.KeyPerspectives) + + serverKeyAPI := base.CreateHTTPServerKeyAPIs() + keyRing := serverKeyAPI.KeyRing() + rsAPI := base.CreateHTTPRoomserverAPIs() fsAPI := federationsender.SetupFederationSenderComponent( - base, federation, rsAPI, &keyRing, + base, federation, rsAPI, keyRing, ) rsAPI.SetFederationSenderAPI(fsAPI) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 6d3a479ac..41907cc0c 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -27,15 +27,15 @@ import ( "github.com/matrix-org/dendrite/federationsender" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/basecomponent" - "github.com/matrix-org/dendrite/internal/keydb" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/transactions" "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/mediaapi" "github.com/matrix-org/dendrite/publicroomsapi" "github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/serverkeyapi" "github.com/matrix-org/dendrite/syncapi" - "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sirupsen/logrus" ) @@ -45,60 +45,95 @@ var ( httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS") - enableHTTPAPIs = flag.Bool("api", false, "Expose internal HTTP APIs in monolith mode") + enableHTTPAPIs = flag.Bool("api", false, "Use HTTP APIs instead of short-circuiting (warning: exposes API endpoints!)") ) func main() { cfg := basecomponent.ParseMonolithFlags() + if *enableHTTPAPIs { + // If the HTTP APIs are enabled then we need to update the Listen + // statements in the configuration so that we know where to find + // the API endpoints. They'll listen on the same port as the monolith + // itself. + addr := config.Address(*httpBindAddr) + cfg.Listen.RoomServer = addr + cfg.Listen.EDUServer = addr + cfg.Listen.AppServiceAPI = addr + cfg.Listen.FederationSender = addr + cfg.Listen.ServerKeyAPI = addr + } + base := basecomponent.NewBaseDendrite(cfg, "Monolith", *enableHTTPAPIs) defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() deviceDB := base.CreateDeviceDB() - keyDB := base.CreateKeyDB() federation := base.CreateFederationClient() - keyRing := keydb.CreateKeyRing(federation.Client, keyDB, cfg.Matrix.KeyPerspectives) - rsAPI := roomserver.SetupRoomServerComponent( + serverKeyAPI := serverkeyapi.SetupServerKeyAPIComponent( + base, federation, + ) + if base.EnableHTTPAPIs { + serverKeyAPI = base.CreateHTTPServerKeyAPIs() + } + keyRing := serverKeyAPI.KeyRing() + + rsComponent := roomserver.SetupRoomServerComponent( base, keyRing, federation, ) + rsAPI := rsComponent + if base.EnableHTTPAPIs { + rsAPI = base.CreateHTTPRoomserverAPIs() + } + eduInputAPI := eduserver.SetupEDUServerComponent( - base, cache.New(), + base, cache.New(), deviceDB, ) + if base.EnableHTTPAPIs { + eduInputAPI = base.CreateHTTPEDUServerAPIs() + } + asAPI := appservice.SetupAppServiceAPIComponent( base, accountDB, deviceDB, federation, rsAPI, transactions.New(), ) + if base.EnableHTTPAPIs { + asAPI = base.CreateHTTPAppServiceAPIs() + } + fsAPI := federationsender.SetupFederationSenderComponent( - base, federation, rsAPI, &keyRing, + base, federation, rsAPI, keyRing, ) - rsAPI.SetFederationSenderAPI(fsAPI) + if base.EnableHTTPAPIs { + fsAPI = base.CreateHTTPFederationSenderAPIs() + } + rsComponent.SetFederationSenderAPI(fsAPI) clientapi.SetupClientAPIComponent( base, deviceDB, accountDB, - federation, &keyRing, rsAPI, + federation, keyRing, rsAPI, eduInputAPI, asAPI, transactions.New(), fsAPI, ) + keyserver.SetupKeyServerComponent( base, deviceDB, accountDB, ) eduProducer := producers.NewEDUServerProducer(eduInputAPI) - federationapi.SetupFederationAPIComponent(base, accountDB, deviceDB, federation, &keyRing, rsAPI, asAPI, fsAPI, eduProducer) + federationapi.SetupFederationAPIComponent(base, accountDB, deviceDB, federation, keyRing, rsAPI, asAPI, fsAPI, eduProducer) mediaapi.SetupMediaAPIComponent(base, deviceDB) - publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI), base.Cfg.DbProperties()) + publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI), base.Cfg.DbProperties(), cfg.Matrix.ServerName) if err != nil { logrus.WithError(err).Panicf("failed to connect to public rooms db") } publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB, publicRoomsDB, rsAPI, federation, nil) syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, rsAPI, federation, cfg) - httpHandler := internal.WrapHandlerInCORS(base.APIMux) - - // Set up the API endpoints we handle. /metrics is for prometheus, and is - // not wrapped by CORS, while everything else is - if cfg.Metrics.Enabled { - http.Handle("/metrics", internal.WrapHandlerInBasicAuth(promhttp.Handler(), cfg.Metrics.BasicAuth)) - } - http.Handle("/", httpHandler) + internal.SetupHTTPAPI( + http.DefaultServeMux, + base.PublicAPIMux, + base.InternalAPIMux, + cfg, + base.EnableHTTPAPIs, + ) // Expose the matrix APIs directly rather than putting them under a /api path. go func() { diff --git a/cmd/dendrite-public-rooms-api-server/main.go b/cmd/dendrite-public-rooms-api-server/main.go index cf9d09c11..d506e32d1 100644 --- a/cmd/dendrite-public-rooms-api-server/main.go +++ b/cmd/dendrite-public-rooms-api-server/main.go @@ -32,7 +32,7 @@ func main() { rsAPI := base.CreateHTTPRoomserverAPIs() rsAPI.SetFederationSenderAPI(fsAPI) - publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI), base.Cfg.DbProperties()) + publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI), base.Cfg.DbProperties(), cfg.Matrix.ServerName) if err != nil { logrus.WithError(err).Panicf("failed to connect to public rooms db") } diff --git a/cmd/dendrite-room-server/main.go b/cmd/dendrite-room-server/main.go index 31cf4a042..b86b74972 100644 --- a/cmd/dendrite-room-server/main.go +++ b/cmd/dendrite-room-server/main.go @@ -16,7 +16,6 @@ package main import ( "github.com/matrix-org/dendrite/internal/basecomponent" - "github.com/matrix-org/dendrite/internal/keydb" "github.com/matrix-org/dendrite/roomserver" ) @@ -24,9 +23,10 @@ func main() { cfg := basecomponent.ParseFlags() base := basecomponent.NewBaseDendrite(cfg, "RoomServerAPI", true) defer base.Close() // nolint: errcheck - keyDB := base.CreateKeyDB() federation := base.CreateFederationClient() - keyRing := keydb.CreateKeyRing(federation.Client, keyDB, cfg.Matrix.KeyPerspectives) + + serverKeyAPI := base.CreateHTTPServerKeyAPIs() + keyRing := serverKeyAPI.KeyRing() fsAPI := base.CreateHTTPFederationSenderAPIs() rsAPI := roomserver.SetupRoomServerComponent(base, keyRing, federation) diff --git a/cmd/dendrite-server-key-api-server/main.go b/cmd/dendrite-server-key-api-server/main.go new file mode 100644 index 000000000..075734125 --- /dev/null +++ b/cmd/dendrite-server-key-api-server/main.go @@ -0,0 +1,32 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/serverkeyapi" +) + +func main() { + cfg := basecomponent.ParseFlags() + base := basecomponent.NewBaseDendrite(cfg, "ServerKeyAPI", true) + defer base.Close() // nolint: errcheck + + federation := base.CreateFederationClient() + + serverkeyapi.SetupServerKeyAPIComponent(base, federation) + + base.SetupAndServeHTTP(string(base.Cfg.Bind.ServerKeyAPI), string(base.Cfg.Listen.ServerKeyAPI)) +} diff --git a/cmd/dendritejs/jsServer.go b/cmd/dendritejs/jsServer.go index a5ac574d8..31f122645 100644 --- a/cmd/dendritejs/jsServer.go +++ b/cmd/dendritejs/jsServer.go @@ -49,15 +49,11 @@ func (h *JSServer) OnRequestFromJS(this js.Value, args []js.Value) interface{} { // we need to put this in an immediately invoked goroutine. go func() { resolve := pargs[0] - fmt.Println("Received request:") - fmt.Printf("%s\n", httpStr) resStr, err := h.handle(httpStr) errStr := "" if err != nil { errStr = err.Error() } - fmt.Println("Sending response:") - fmt.Printf("%s\n", resStr) resolve.Invoke(map[string]interface{}{ "result": resStr, "error": errStr, diff --git a/cmd/dendritejs/keyfetcher.go b/cmd/dendritejs/keyfetcher.go index ee4905d4f..0073b7e6b 100644 --- a/cmd/dendritejs/keyfetcher.go +++ b/cmd/dendritejs/keyfetcher.go @@ -23,7 +23,6 @@ import ( "github.com/libp2p/go-libp2p-core/peer" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) const libp2pMatrixKeyID = "ed25519:libp2p-dendrite" @@ -63,8 +62,6 @@ func (f *libp2pKeyFetcher) FetchKeys( if err != nil { return nil, fmt.Errorf("Failed to extract raw bytes from public key: %w", err) } - util.GetLogger(ctx).Info("libp2pKeyFetcher.FetchKeys: Using public key %v for server name %s", pubKeyBytes, req.ServerName) - b64Key := gomatrixserverlib.Base64String(pubKeyBytes) res[req] = gomatrixserverlib.PublicKeyLookupResult{ VerifyKey: gomatrixserverlib.VerifyKey{ @@ -82,3 +79,8 @@ func (f *libp2pKeyFetcher) FetchKeys( func (f *libp2pKeyFetcher) FetcherName() string { return "libp2pKeyFetcher" } + +// no-op function for storing keys - we don't do any work to fetch them so don't bother storing. +func (f *libp2pKeyFetcher) StoreKeys(ctx context.Context, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) error { + return nil +} diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index 2609c74cd..c1aef44d9 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -39,6 +39,7 @@ import ( "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/syncapi" go_http_js_libp2p "github.com/matrix-org/go-http-js-libp2p" + "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" @@ -163,18 +164,19 @@ func main() { cfg := &config.Dendrite{} cfg.SetDefaults() cfg.Kafka.UseNaffka = true - cfg.Database.Account = "file:dendritejs_account.db" - cfg.Database.AppService = "file:dendritejs_appservice.db" - cfg.Database.Device = "file:dendritejs_device.db" - cfg.Database.FederationSender = "file:dendritejs_fedsender.db" - cfg.Database.MediaAPI = "file:dendritejs_mediaapi.db" - cfg.Database.Naffka = "file:dendritejs_naffka.db" - cfg.Database.PublicRoomsAPI = "file:dendritejs_publicrooms.db" - cfg.Database.RoomServer = "file:dendritejs_roomserver.db" - cfg.Database.ServerKey = "file:dendritejs_serverkey.db" - cfg.Database.SyncAPI = "file:dendritejs_syncapi.db" + cfg.Database.Account = "file:/idb/dendritejs_account.db" + cfg.Database.AppService = "file:/idb/dendritejs_appservice.db" + cfg.Database.Device = "file:/idb/dendritejs_device.db" + cfg.Database.FederationSender = "file:/idb/dendritejs_fedsender.db" + cfg.Database.MediaAPI = "file:/idb/dendritejs_mediaapi.db" + cfg.Database.Naffka = "file:/idb/dendritejs_naffka.db" + cfg.Database.PublicRoomsAPI = "file:/idb/dendritejs_publicrooms.db" + cfg.Database.RoomServer = "file:/idb/dendritejs_roomserver.db" + cfg.Database.ServerKey = "file:/idb/dendritejs_serverkey.db" + cfg.Database.SyncAPI = "file:/idb/dendritejs_syncapi.db" cfg.Kafka.Topics.UserUpdates = "user_updates" cfg.Kafka.Topics.OutputTypingEvent = "output_typing_event" + cfg.Kafka.Topics.OutputSendToDeviceEvent = "output_send_to_device_event" cfg.Kafka.Topics.OutputClientData = "output_client_data" cfg.Kafka.Topics.OutputRoomEvent = "output_room_event" cfg.Matrix.TrustedIDServers = []string{ @@ -194,23 +196,24 @@ func main() { accountDB := base.CreateAccountsDB() deviceDB := base.CreateDeviceDB() - keyDB := base.CreateKeyDB() federation := createFederationClient(cfg, node) + + fetcher := &libp2pKeyFetcher{} keyRing := gomatrixserverlib.KeyRing{ KeyFetchers: []gomatrixserverlib.KeyFetcher{ - &libp2pKeyFetcher{}, + fetcher, }, - KeyDatabase: keyDB, + KeyDatabase: fetcher, } - p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node) rsAPI := roomserver.SetupRoomServerComponent(base, keyRing, federation) - eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New()) + eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB) asQuery := appservice.SetupAppServiceAPIComponent( base, accountDB, deviceDB, federation, rsAPI, transactions.New(), ) fedSenderAPI := federationsender.SetupFederationSenderComponent(base, federation, rsAPI, &keyRing) rsAPI.SetFederationSenderAPI(fedSenderAPI) + p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI) clientapi.SetupClientAPIComponent( base, deviceDB, accountDB, @@ -220,16 +223,20 @@ func main() { eduProducer := producers.NewEDUServerProducer(eduInputAPI) federationapi.SetupFederationAPIComponent(base, accountDB, deviceDB, federation, &keyRing, rsAPI, asQuery, fedSenderAPI, eduProducer) mediaapi.SetupMediaAPIComponent(base, deviceDB) - publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI)) + publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI), cfg.Matrix.ServerName) if err != nil { logrus.WithError(err).Panicf("failed to connect to public rooms db") } publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB, publicRoomsDB, rsAPI, federation, p2pPublicRoomProvider) syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, rsAPI, federation, cfg) - httpHandler := internal.WrapHandlerInCORS(base.APIMux) - - http.Handle("/", httpHandler) + internal.SetupHTTPAPI( + http.DefaultServeMux, + base.PublicAPIMux, + base.InternalAPIMux, + cfg, + base.EnableHTTPAPIs, + ) // Expose the matrix APIs via libp2p-js - for federation traffic if node != nil { diff --git a/cmd/dendritejs/publicrooms.go b/cmd/dendritejs/publicrooms.go index 5246cf535..5032bc15f 100644 --- a/cmd/dendritejs/publicrooms.go +++ b/cmd/dendritejs/publicrooms.go @@ -17,23 +17,48 @@ package main import ( + "context" + + "github.com/matrix-org/dendrite/federationsender/api" go_http_js_libp2p "github.com/matrix-org/go-http-js-libp2p" + "github.com/matrix-org/gomatrixserverlib" ) type libp2pPublicRoomsProvider struct { node *go_http_js_libp2p.P2pLocalNode providers []go_http_js_libp2p.PeerInfo + fedSender api.FederationSenderInternalAPI } -func NewLibP2PPublicRoomsProvider(node *go_http_js_libp2p.P2pLocalNode) *libp2pPublicRoomsProvider { +func NewLibP2PPublicRoomsProvider(node *go_http_js_libp2p.P2pLocalNode, fedSender api.FederationSenderInternalAPI) *libp2pPublicRoomsProvider { p := &libp2pPublicRoomsProvider{ - node: node, + node: node, + fedSender: fedSender, } node.RegisterFoundProviders(p.foundProviders) return p } func (p *libp2pPublicRoomsProvider) foundProviders(peerInfos []go_http_js_libp2p.PeerInfo) { + // work out the diff then poke for new ones + seen := make(map[string]bool, len(p.providers)) + for _, pr := range p.providers { + seen[pr.Id] = true + } + var newPeers []gomatrixserverlib.ServerName + for _, pi := range peerInfos { + if !seen[pi.Id] { + newPeers = append(newPeers, gomatrixserverlib.ServerName(pi.Id)) + } + } + if len(newPeers) > 0 { + var res api.PerformServersAliveResponse + // ignore errors, we don't care. + p.fedSender.PerformServersAlive(context.Background(), &api.PerformServersAliveRequest{ + Servers: newPeers, + }, &res) + } + p.providers = peerInfos } diff --git a/cmd/federation-api-proxy/main.go b/cmd/federation-api-proxy/main.go index fa90482d5..7324de148 100644 --- a/cmd/federation-api-proxy/main.go +++ b/cmd/federation-api-proxy/main.go @@ -73,7 +73,6 @@ func makeProxy(targetURL string) (*httputil.ReverseProxy, error) { // Pratically this means that any distinction between '%2F' and '/' // in the URL will be lost by the time it reaches the target. path := req.URL.Path - path = "api" + path log.WithFields(log.Fields{ "path": path, "url": targetURL, diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 2616d74db..a5b295977 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -104,7 +104,8 @@ kafka: topics: output_room_event: roomserverOutput output_client_data: clientapiOutput - output_typing_event: eduServerOutput + output_typing_event: eduServerTypingOutput + output_send_to_device_event: eduServerSendToDeviceOutput user_updates: userUpdates # The postgres connection configs for connecting to the databases e.g a postgres:// URI @@ -138,6 +139,7 @@ listen: appservice_api: "localhost:7777" edu_server: "localhost:7778" key_server: "localhost:7779" + server_key_api: "localhost:7780" # The configuration for tracing the dendrite components. tracing: diff --git a/docs/p2p.md b/docs/p2p.md index 141aaa1fc..d69b47bea 100644 --- a/docs/p2p.md +++ b/docs/p2p.md @@ -1,6 +1,6 @@ ## Peer-to-peer Matrix -These are the instructions for setting up P2P Dendrite, current as of March 2020. There's both Go stuff and JS stuff to do to set this up. +These are the instructions for setting up P2P Dendrite, current as of May 2020. There's both Go stuff and JS stuff to do to set this up. ### Dendrite @@ -28,14 +28,13 @@ Then use `/ip4/127.0.0.1/tcp/9090/ws/p2p-websocket-star/`. ### Riot-web -You need to check out these repos: +You need to check out this repo: ``` $ git clone git@github.com:matrix-org/go-http-js-libp2p.git -$ git clone git@github.com:matrix-org/go-sqlite3-js.git ``` -Make sure to `yarn install` in both of these repos. Then: +Make sure to `yarn install` in the repo. Then: - `$ cp "$(go env GOROOT)/misc/wasm/wasm_exec.js" ./src/vector/` - Comment out the lines in `wasm_exec.js` which contains: @@ -49,7 +48,6 @@ if (!global.fs && global.require) { - Add the following symlinks: they HAVE to be symlinks as the diff in `webpack.config.js` references specific paths. ``` $ cd node_modules -$ ln -s ../../go-sqlite-js # NB: NOT go-sqlite3-js $ ln -s ../../go-http-js-libp2p ``` @@ -65,14 +63,7 @@ You need a Chrome and a Firefox running to test locally as service workers don't Assuming you've `yarn start`ed Riot-Web, go to `http://localhost:8080` and register with `http://localhost:8080` as your HS URL. -You can join rooms by room alias e.g `/join #foo:bar`. - -### Known issues - -- When registering you may be unable to find the server, it'll seem flakey. This happens because the SW, particularly in Firefox, - gets killed after 30s of inactivity. When you are not registered, you aren't doing `/sync` calls to keep the SW alive, so if you - don't register for a while and idle on the page, the HS will disappear. To fix, unregister the SW, and then refresh the page. - -- The libp2p layer has rate limits, so frequent Federation traffic may cause the connection to drop and messages to not be transferred. - I guess in other words, don't send too much traffic? - +You can: + - join rooms by room alias e.g `/join #foo:bar`. + - invite specific users to a room. + - explore the published room list. All members of the room can re-publish aliases (unlike Synapse). diff --git a/eduserver/api/input.go b/eduserver/api/input.go index 999a5b41a..fa7f30cb6 100644 --- a/eduserver/api/input.go +++ b/eduserver/api/input.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -37,6 +41,12 @@ type InputTypingEvent struct { OriginServerTS gomatrixserverlib.Timestamp `json:"origin_server_ts"` } +type InputSendToDeviceEvent struct { + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + gomatrixserverlib.SendToDeviceEvent +} + // InputTypingEventRequest is a request to EDUServerInputAPI type InputTypingEventRequest struct { InputTypingEvent InputTypingEvent `json:"input_typing_event"` @@ -45,6 +55,14 @@ type InputTypingEventRequest struct { // InputTypingEventResponse is a response to InputTypingEvents type InputTypingEventResponse struct{} +// InputSendToDeviceEventRequest is a request to EDUServerInputAPI +type InputSendToDeviceEventRequest struct { + InputSendToDeviceEvent InputSendToDeviceEvent `json:"input_send_to_device_event"` +} + +// InputSendToDeviceEventResponse is a response to InputSendToDeviceEventRequest +type InputSendToDeviceEventResponse struct{} + // EDUServerInputAPI is used to write events to the typing server. type EDUServerInputAPI interface { InputTypingEvent( @@ -52,10 +70,19 @@ type EDUServerInputAPI interface { request *InputTypingEventRequest, response *InputTypingEventResponse, ) error + + InputSendToDeviceEvent( + ctx context.Context, + request *InputSendToDeviceEventRequest, + response *InputSendToDeviceEventResponse, + ) error } // EDUServerInputTypingEventPath is the HTTP path for the InputTypingEvent API. -const EDUServerInputTypingEventPath = "/api/eduserver/input" +const EDUServerInputTypingEventPath = "/eduserver/input" + +// EDUServerInputSendToDeviceEventPath is the HTTP path for the InputSendToDeviceEvent API. +const EDUServerInputSendToDeviceEventPath = "/eduserver/sendToDevice" // NewEDUServerInputAPIHTTP creates a EDUServerInputAPI implemented by talking to a HTTP POST API. func NewEDUServerInputAPIHTTP(eduServerURL string, httpClient *http.Client) (EDUServerInputAPI, error) { @@ -70,7 +97,7 @@ type httpEDUServerInputAPI struct { httpClient *http.Client } -// InputRoomEvents implements EDUServerInputAPI +// InputTypingEvent implements EDUServerInputAPI func (h *httpEDUServerInputAPI) InputTypingEvent( ctx context.Context, request *InputTypingEventRequest, @@ -82,3 +109,16 @@ func (h *httpEDUServerInputAPI) InputTypingEvent( apiURL := h.eduServerURL + EDUServerInputTypingEventPath return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } + +// InputSendToDeviceEvent implements EDUServerInputAPI +func (h *httpEDUServerInputAPI) InputSendToDeviceEvent( + ctx context.Context, + request *InputSendToDeviceEventRequest, + response *InputSendToDeviceEventResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputSendToDeviceEvent") + defer span.Finish() + + apiURL := h.eduServerURL + EDUServerInputSendToDeviceEventPath + return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/eduserver/api/output.go b/eduserver/api/output.go index 8696acf49..e6ded8413 100644 --- a/eduserver/api/output.go +++ b/eduserver/api/output.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -12,7 +16,11 @@ package api -import "time" +import ( + "time" + + "github.com/matrix-org/gomatrixserverlib" +) // OutputTypingEvent is an entry in typing server output kafka log. // This contains the event with extra fields used to create 'm.typing' event @@ -32,3 +40,12 @@ type TypingEvent struct { UserID string `json:"user_id"` Typing bool `json:"typing"` } + +// OutputSendToDeviceEvent is an entry in the send-to-device output kafka log. +// This contains the full event content, along with the user ID and device ID +// to which it is destined. +type OutputSendToDeviceEvent struct { + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + gomatrixserverlib.SendToDeviceEvent +} diff --git a/eduserver/cache/cache.go b/eduserver/cache/cache.go index 46f7a2b13..dd535a6d2 100644 --- a/eduserver/cache/cache.go +++ b/eduserver/cache/cache.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -109,6 +113,19 @@ func (t *EDUCache) AddTypingUser( return t.GetLatestSyncPosition() } +// AddSendToDeviceMessage increases the sync position for +// send-to-device updates. +// Returns the sync position before update, as the caller +// will use this to record the current stream position +// at the time that the send-to-device message was sent. +func (t *EDUCache) AddSendToDeviceMessage() int64 { + t.Lock() + defer t.Unlock() + latestSyncPosition := t.latestSyncPosition + t.latestSyncPosition++ + return latestSyncPosition +} + // addUser with mutex lock & replace the previous timer. // Returns the latest typing sync position after update. func (t *EDUCache) addUser( diff --git a/eduserver/cache/cache_test.go b/eduserver/cache/cache_test.go index d1b2f8bd5..c7d01879f 100644 --- a/eduserver/cache/cache_test.go +++ b/eduserver/cache/cache_test.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at diff --git a/eduserver/eduserver.go b/eduserver/eduserver.go index cba60b8ea..6f664eb67 100644 --- a/eduserver/eduserver.go +++ b/eduserver/eduserver.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -13,8 +17,7 @@ package eduserver import ( - "net/http" - + "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/eduserver/input" @@ -28,16 +31,18 @@ import ( func SetupEDUServerComponent( base *basecomponent.BaseDendrite, eduCache *cache.EDUCache, + deviceDB devices.Database, ) api.EDUServerInputAPI { inputAPI := &input.EDUServerInputAPI{ - Cache: eduCache, - Producer: base.KafkaProducer, - OutputTypingEventTopic: string(base.Cfg.Kafka.Topics.OutputTypingEvent), + Cache: eduCache, + DeviceDB: deviceDB, + Producer: base.KafkaProducer, + OutputTypingEventTopic: string(base.Cfg.Kafka.Topics.OutputTypingEvent), + OutputSendToDeviceEventTopic: string(base.Cfg.Kafka.Topics.OutputSendToDeviceEvent), + ServerName: base.Cfg.Matrix.ServerName, } - if base.EnableHTTPAPIs { - inputAPI.SetupHTTP(http.DefaultServeMux) - } + inputAPI.SetupHTTP(base.InternalAPIMux) return inputAPI } diff --git a/eduserver/input/input.go b/eduserver/input/input.go index 50837154a..4e3051950 100644 --- a/eduserver/input/input.go +++ b/eduserver/input/input.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -19,11 +23,14 @@ import ( "time" "github.com/Shopify/sarama" + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) // EDUServerInputAPI implements api.EDUServerInputAPI @@ -32,8 +39,14 @@ type EDUServerInputAPI struct { Cache *cache.EDUCache // The kafka topic to output new typing events to. OutputTypingEventTopic string + // The kafka topic to output new send to device events to. + OutputSendToDeviceEventTopic string // kafka producer Producer sarama.SyncProducer + // device database + DeviceDB devices.Database + // our server name + ServerName gomatrixserverlib.ServerName } // InputTypingEvent implements api.EDUServerInputAPI @@ -53,10 +66,20 @@ func (t *EDUServerInputAPI) InputTypingEvent( t.Cache.RemoveUser(ite.UserID, ite.RoomID) } - return t.sendEvent(ite) + return t.sendTypingEvent(ite) } -func (t *EDUServerInputAPI) sendEvent(ite *api.InputTypingEvent) error { +// InputTypingEvent implements api.EDUServerInputAPI +func (t *EDUServerInputAPI) InputSendToDeviceEvent( + ctx context.Context, + request *api.InputSendToDeviceEventRequest, + response *api.InputSendToDeviceEventResponse, +) error { + ise := &request.InputSendToDeviceEvent + return t.sendToDeviceEvent(ise) +} + +func (t *EDUServerInputAPI) sendTypingEvent(ite *api.InputTypingEvent) error { ev := &api.TypingEvent{ Type: gomatrixserverlib.MTyping, RoomID: ite.RoomID, @@ -89,9 +112,68 @@ func (t *EDUServerInputAPI) sendEvent(ite *api.InputTypingEvent) error { return err } +func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) error { + devices := []string{} + localpart, domain, err := gomatrixserverlib.SplitID('@', ise.UserID) + if err != nil { + return err + } + + // If the event is targeted locally then we want to expand the wildcard + // out into individual device IDs so that we can send them to each respective + // device. If the event isn't targeted locally then we can't expand the + // wildcard as we don't know about the remote devices, so instead we leave it + // as-is, so that the federation sender can send it on with the wildcard intact. + if domain == t.ServerName && ise.DeviceID == "*" { + devs, err := t.DeviceDB.GetDevicesByLocalpart(context.TODO(), localpart) + if err != nil { + return err + } + for _, dev := range devs { + devices = append(devices, dev.ID) + } + } else { + devices = append(devices, ise.DeviceID) + } + + for _, device := range devices { + ote := &api.OutputSendToDeviceEvent{ + UserID: ise.UserID, + DeviceID: device, + SendToDeviceEvent: ise.SendToDeviceEvent, + } + + logrus.WithFields(logrus.Fields{ + "user_id": ise.UserID, + "device_id": ise.DeviceID, + "event_type": ise.Type, + }).Info("handling send-to-device message") + + eventJSON, err := json.Marshal(ote) + if err != nil { + logrus.WithError(err).Error("sendToDevice failed json.Marshal") + return err + } + + m := &sarama.ProducerMessage{ + Topic: string(t.OutputSendToDeviceEventTopic), + Key: sarama.StringEncoder(ote.UserID), + Value: sarama.ByteEncoder(eventJSON), + } + + _, _, err = t.Producer.SendMessage(m) + if err != nil { + logrus.WithError(err).Error("sendToDevice failed t.Producer.SendMessage") + return err + } + } + + return nil +} + // SetupHTTP adds the EDUServerInputAPI handlers to the http.ServeMux. -func (t *EDUServerInputAPI) SetupHTTP(servMux *http.ServeMux) { - servMux.Handle(api.EDUServerInputTypingEventPath, +func (t *EDUServerInputAPI) SetupHTTP(internalAPIMux *mux.Router) { + internalAPIMux.Handle(api.EDUServerInputTypingEventPath, internal.MakeInternalAPI("inputTypingEvents", func(req *http.Request) util.JSONResponse { var request api.InputTypingEventRequest var response api.InputTypingEventResponse @@ -104,4 +186,17 @@ func (t *EDUServerInputAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(api.EDUServerInputSendToDeviceEventPath, + internal.MakeInternalAPI("inputSendToDeviceEvents", func(req *http.Request) util.JSONResponse { + var request api.InputSendToDeviceEventRequest + var response api.InputSendToDeviceEventResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := t.InputSendToDeviceEvent(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 9e3411403..baeaa36bc 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -44,7 +44,7 @@ func SetupFederationAPIComponent( roomserverProducer := producers.NewRoomserverProducer(rsAPI) routing.Setup( - base.APIMux, base.Cfg, rsAPI, asAPI, roomserverProducer, + base.PublicAPIMux, base.Cfg, rsAPI, asAPI, roomserverProducer, eduProducer, federationSenderAPI, *keyRing, federation, accountsDB, deviceDB, ) diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index 49686c761..caf5fe592 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -15,19 +15,13 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) -type userDevicesResponse struct { - UserID string `json:"user_id"` - StreamID int `json:"stream_id"` - Devices []authtypes.Device `json:"devices"` -} - // GetUserDevices for the given user id func GetUserDevices( req *http.Request, @@ -42,20 +36,30 @@ func GetUserDevices( } } + response := gomatrixserverlib.RespUserDevices{ + UserID: userID, + // TODO: we should return an incrementing stream ID each time the device + // list changes for delta changes to be recognised + StreamID: 0, + } + devs, err := deviceDB.GetDevicesByLocalpart(req.Context(), localpart) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDevicesByLocalPart failed") return jsonerror.InternalServerError() } + for _, dev := range devs { + device := gomatrixserverlib.RespUserDevice{ + DeviceID: dev.ID, + DisplayName: dev.DisplayName, + Keys: []gomatrixserverlib.RespUserDeviceKeys{}, + } + response.Devices = append(response.Devices, device) + } + return util.JSONResponse{ Code: 200, - // TODO: we should return an incrementing stream ID each time the device - // list changes for delta changes to be recognised - JSON: userDevicesResponse{ - UserID: userID, - StreamID: 0, - Devices: devs, - }, + JSON: response, } } diff --git a/federationapi/routing/missingevents.go b/federationapi/routing/missingevents.go index ae91c5891..f93e0eb41 100644 --- a/federationapi/routing/missingevents.go +++ b/federationapi/routing/missingevents.go @@ -60,9 +60,14 @@ func GetMissingEvents( } eventsResponse.Events = filterEvents(eventsResponse.Events, gme.MinDepth, roomID) + + resp := gomatrixserverlib.RespMissingEvents{ + Events: gomatrixserverlib.UnwrapEventHeaders(eventsResponse.Events), + } + return util.JSONResponse{ Code: http.StatusOK, - JSON: eventsResponse, + JSON: resp, } } diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index c368c16c4..b75a7941d 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -31,9 +31,9 @@ import ( ) const ( - pathPrefixV2Keys = "/_matrix/key/v2" - pathPrefixV1Federation = "/_matrix/federation/v1" - pathPrefixV2Federation = "/_matrix/federation/v2" + pathPrefixV2Keys = "/key/v2" + pathPrefixV1Federation = "/federation/v1" + pathPrefixV2Federation = "/federation/v2" ) // Setup registers HTTP handlers with the given ServeMux. @@ -42,21 +42,25 @@ const ( // applied: // nolint: gocyclo func Setup( - apiMux *mux.Router, + publicAPIMux *mux.Router, cfg *config.Dendrite, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, producer *producers.RoomserverProducer, eduProducer *producers.EDUServerProducer, - federationSenderAPI federationSenderAPI.FederationSenderInternalAPI, + fsAPI federationSenderAPI.FederationSenderInternalAPI, keys gomatrixserverlib.KeyRing, federation *gomatrixserverlib.FederationClient, accountDB accounts.Database, deviceDB devices.Database, ) { - v2keysmux := apiMux.PathPrefix(pathPrefixV2Keys).Subrouter() - v1fedmux := apiMux.PathPrefix(pathPrefixV1Federation).Subrouter() - v2fedmux := apiMux.PathPrefix(pathPrefixV2Federation).Subrouter() + v2keysmux := publicAPIMux.PathPrefix(pathPrefixV2Keys).Subrouter() + v1fedmux := publicAPIMux.PathPrefix(pathPrefixV1Federation).Subrouter() + v2fedmux := publicAPIMux.PathPrefix(pathPrefixV2Federation).Subrouter() + + wakeup := &internal.FederationWakeups{ + FsAPI: fsAPI, + } localKeys := internal.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse { return LocalKeys(cfg) @@ -71,7 +75,7 @@ func Setup( v2keysmux.Handle("/server", localKeys).Methods(http.MethodGet) v1fedmux.Handle("/send/{txnID}", internal.MakeFedAPI( - "federation_send", cfg.Matrix.ServerName, keys, + "federation_send", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := internal.URLDecodeMapValues(mux.Vars(httpReq)) if err != nil { @@ -85,7 +89,7 @@ func Setup( )).Methods(http.MethodPut, http.MethodOptions) v2fedmux.Handle("/invite/{roomID}/{eventID}", internal.MakeFedAPI( - "federation_invite", cfg.Matrix.ServerName, keys, + "federation_invite", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := internal.URLDecodeMapValues(mux.Vars(httpReq)) if err != nil { @@ -105,7 +109,7 @@ func Setup( )).Methods(http.MethodPost, http.MethodOptions) v1fedmux.Handle("/exchange_third_party_invite/{roomID}", internal.MakeFedAPI( - "exchange_third_party_invite", cfg.Matrix.ServerName, keys, + "exchange_third_party_invite", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := internal.URLDecodeMapValues(mux.Vars(httpReq)) if err != nil { @@ -118,7 +122,7 @@ func Setup( )).Methods(http.MethodPut, http.MethodOptions) v1fedmux.Handle("/event/{eventID}", internal.MakeFedAPI( - "federation_get_event", cfg.Matrix.ServerName, keys, + "federation_get_event", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := internal.URLDecodeMapValues(mux.Vars(httpReq)) if err != nil { @@ -131,7 +135,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/state/{roomID}", internal.MakeFedAPI( - "federation_get_state", cfg.Matrix.ServerName, keys, + "federation_get_state", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := internal.URLDecodeMapValues(mux.Vars(httpReq)) if err != nil { @@ -144,7 +148,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/state_ids/{roomID}", internal.MakeFedAPI( - "federation_get_state_ids", cfg.Matrix.ServerName, keys, + "federation_get_state_ids", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := internal.URLDecodeMapValues(mux.Vars(httpReq)) if err != nil { @@ -157,7 +161,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/event_auth/{roomID}/{eventID}", internal.MakeFedAPI( - "federation_get_event_auth", cfg.Matrix.ServerName, keys, + "federation_get_event_auth", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars := mux.Vars(httpReq) return GetEventAuth( @@ -167,16 +171,16 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/query/directory", internal.MakeFedAPI( - "federation_query_room_alias", cfg.Matrix.ServerName, keys, + "federation_query_room_alias", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { return RoomAliasToID( - httpReq, federation, cfg, rsAPI, federationSenderAPI, + httpReq, federation, cfg, rsAPI, fsAPI, ) }, )).Methods(http.MethodGet) v1fedmux.Handle("/query/profile", internal.MakeFedAPI( - "federation_query_profile", cfg.Matrix.ServerName, keys, + "federation_query_profile", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { return GetProfile( httpReq, accountDB, cfg, asAPI, @@ -185,7 +189,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/user/devices/{userID}", internal.MakeFedAPI( - "federation_user_devices", cfg.Matrix.ServerName, keys, + "federation_user_devices", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := internal.URLDecodeMapValues(mux.Vars(httpReq)) if err != nil { @@ -198,7 +202,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/make_join/{roomID}/{eventID}", internal.MakeFedAPI( - "federation_make_join", cfg.Matrix.ServerName, keys, + "federation_make_join", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := internal.URLDecodeMapValues(mux.Vars(httpReq)) if err != nil { @@ -227,7 +231,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/send_join/{roomID}/{eventID}", internal.MakeFedAPI( - "federation_send_join", cfg.Matrix.ServerName, keys, + "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := internal.URLDecodeMapValues(mux.Vars(httpReq)) if err != nil { @@ -249,7 +253,7 @@ func Setup( )).Methods(http.MethodPut) v2fedmux.Handle("/send_join/{roomID}/{eventID}", internal.MakeFedAPI( - "federation_send_join", cfg.Matrix.ServerName, keys, + "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := internal.URLDecodeMapValues(mux.Vars(httpReq)) if err != nil { @@ -264,7 +268,7 @@ func Setup( )).Methods(http.MethodPut) v1fedmux.Handle("/make_leave/{roomID}/{eventID}", internal.MakeFedAPI( - "federation_make_leave", cfg.Matrix.ServerName, keys, + "federation_make_leave", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := internal.URLDecodeMapValues(mux.Vars(httpReq)) if err != nil { @@ -279,7 +283,7 @@ func Setup( )).Methods(http.MethodGet) v2fedmux.Handle("/send_leave/{roomID}/{eventID}", internal.MakeFedAPI( - "federation_send_leave", cfg.Matrix.ServerName, keys, + "federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := internal.URLDecodeMapValues(mux.Vars(httpReq)) if err != nil { @@ -301,7 +305,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/get_missing_events/{roomID}", internal.MakeFedAPI( - "federation_get_missing_events", cfg.Matrix.ServerName, keys, + "federation_get_missing_events", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := internal.URLDecodeMapValues(mux.Vars(httpReq)) if err != nil { @@ -312,7 +316,7 @@ func Setup( )).Methods(http.MethodPost) v1fedmux.Handle("/backfill/{roomID}", internal.MakeFedAPI( - "federation_backfill", cfg.Matrix.ServerName, keys, + "federation_backfill", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := internal.URLDecodeMapValues(mux.Vars(httpReq)) if err != nil { diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 6cf44018e..74b4c0144 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -76,6 +76,7 @@ func Send( resp, err := t.processTransaction() if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("t.processTransaction failed") + return util.ErrorResponse(err) } // https://matrix.org/docs/spec/server_server/r0.1.3#put-matrix-federation-v1-send-txnid @@ -117,7 +118,7 @@ type txnFederationClient interface { func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) { results := make(map[string]gomatrixserverlib.PDUResult) - var pdus []gomatrixserverlib.HeaderedEvent + pdus := []gomatrixserverlib.HeaderedEvent{} for _, pdu := range t.PDUs { var header struct { RoomID string `json:"room_id"` @@ -264,6 +265,25 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) { if err := t.eduProducer.SendTyping(t.context, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { util.GetLogger(t.context).WithError(err).Error("Failed to send typing event to edu server") } + case gomatrixserverlib.MDirectToDevice: + // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema + var directPayload gomatrixserverlib.ToDeviceMessage + if err := json.Unmarshal(e.Content, &directPayload); err != nil { + util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal send-to-device events") + continue + } + for userID, byUser := range directPayload.Messages { + for deviceID, message := range byUser { + // TODO: check that the user and the device actually exist here + if err := t.eduProducer.SendToDevice(t.context, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { + util.GetLogger(t.context).WithError(err).WithFields(logrus.Fields{ + "sender": directPayload.Sender, + "user_id": userID, + "device_id": deviceID, + }).Error("Failed to send send-to-device event to edu server") + } + } + } default: util.GetLogger(t.context).WithField("type", e.Type).Warn("unhandled edu") } diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index cb8aec6f5..3e28a3476 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -77,6 +77,14 @@ func (p *testEDUProducer) InputTypingEvent( return nil } +func (p *testEDUProducer) InputSendToDeviceEvent( + ctx context.Context, + request *eduAPI.InputSendToDeviceEventRequest, + response *eduAPI.InputSendToDeviceEventResponse, +) error { + return nil +} + type testRoomserverAPI struct { inputRoomEvents []api.InputRoomEvent queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse diff --git a/federationsender/api/api.go b/federationsender/api/api.go index 678f02e68..4eb20cb67 100644 --- a/federationsender/api/api.go +++ b/federationsender/api/api.go @@ -42,6 +42,12 @@ type FederationSenderInternalAPI interface { request *PerformLeaveRequest, response *PerformLeaveResponse, ) error + // Notifies the federation sender that these servers may be online and to retry sending messages. + PerformServersAlive( + ctx context.Context, + request *PerformServersAliveRequest, + response *PerformServersAliveResponse, + ) error } // NewFederationSenderInternalAPIHTTP creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API. diff --git a/federationsender/api/perform.go b/federationsender/api/perform.go index 427e70683..5e4d7fe97 100644 --- a/federationsender/api/perform.go +++ b/federationsender/api/perform.go @@ -11,13 +11,16 @@ import ( const ( // FederationSenderPerformJoinRequestPath is the HTTP path for the PerformJoinRequest API. - FederationSenderPerformDirectoryLookupRequestPath = "/api/federationsender/performDirectoryLookup" + FederationSenderPerformDirectoryLookupRequestPath = "/federationsender/performDirectoryLookup" // FederationSenderPerformJoinRequestPath is the HTTP path for the PerformJoinRequest API. - FederationSenderPerformJoinRequestPath = "/api/federationsender/performJoinRequest" + FederationSenderPerformJoinRequestPath = "/federationsender/performJoinRequest" // FederationSenderPerformLeaveRequestPath is the HTTP path for the PerformLeaveRequest API. - FederationSenderPerformLeaveRequestPath = "/api/federationsender/performLeaveRequest" + FederationSenderPerformLeaveRequestPath = "/federationsender/performLeaveRequest" + + // FederationSenderPerformServersAlivePath is the HTTP path for the PerformServersAlive API. + FederationSenderPerformServersAlivePath = "/federationsender/performServersAlive" ) type PerformDirectoryLookupRequest struct { @@ -44,8 +47,9 @@ func (h *httpFederationSenderInternalAPI) PerformDirectoryLookup( } type PerformJoinRequest struct { - RoomID string `json:"room_id"` - UserID string `json:"user_id"` + RoomID string `json:"room_id"` + UserID string `json:"user_id"` + // The sorted list of servers to try. Servers will be tried sequentially, after de-duplication. ServerNames types.ServerNames `json:"server_names"` Content map[string]interface{} `json:"content"` } @@ -87,3 +91,22 @@ func (h *httpFederationSenderInternalAPI) PerformLeave( apiURL := h.federationSenderURL + FederationSenderPerformLeaveRequestPath return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } + +type PerformServersAliveRequest struct { + Servers []gomatrixserverlib.ServerName +} + +type PerformServersAliveResponse struct { +} + +func (h *httpFederationSenderInternalAPI) PerformServersAlive( + ctx context.Context, + request *PerformServersAliveRequest, + response *PerformServersAliveResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformServersAlive") + defer span.Finish() + + apiURL := h.federationSenderURL + FederationSenderPerformServersAlivePath + return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/federationsender/api/query.go b/federationsender/api/query.go index bf58d5cc9..4c0f757b2 100644 --- a/federationsender/api/query.go +++ b/federationsender/api/query.go @@ -11,10 +11,10 @@ import ( ) // FederationSenderQueryJoinedHostsInRoomPath is the HTTP path for the QueryJoinedHostsInRoom API. -const FederationSenderQueryJoinedHostsInRoomPath = "/api/federationsender/queryJoinedHostsInRoom" +const FederationSenderQueryJoinedHostsInRoomPath = "/federationsender/queryJoinedHostsInRoom" // FederationSenderQueryJoinedHostServerNamesInRoomPath is the HTTP path for the QueryJoinedHostServerNamesInRoom API. -const FederationSenderQueryJoinedHostServerNamesInRoomPath = "/api/federationsender/queryJoinedHostServerNamesInRoom" +const FederationSenderQueryJoinedHostServerNamesInRoomPath = "/federationsender/queryJoinedHostServerNamesInRoom" // QueryJoinedHostsInRoomRequest is a request to QueryJoinedHostsInRoom type QueryJoinedHostsInRoomRequest struct { diff --git a/federationsender/federationsender.go b/federationsender/federationsender.go index 446f920c2..9e5cc8dd1 100644 --- a/federationsender/federationsender.go +++ b/federationsender/federationsender.go @@ -15,8 +15,6 @@ package federationsender import ( - "net/http" - "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/federationsender/consumers" "github.com/matrix-org/dendrite/federationsender/internal" @@ -69,12 +67,10 @@ func SetupFederationSenderComponent( queryAPI := internal.NewFederationSenderInternalAPI( federationSenderDB, base.Cfg, roomserverProducer, federation, keyRing, - statistics, + statistics, queues, ) - if base.EnableHTTPAPIs { - queryAPI.SetupHTTP(http.DefaultServeMux) - } + queryAPI.SetupHTTP(base.InternalAPIMux) return queryAPI } diff --git a/federationsender/internal/api.go b/federationsender/internal/api.go index 1f6d857f4..edf8fb4e9 100644 --- a/federationsender/internal/api.go +++ b/federationsender/internal/api.go @@ -4,8 +4,10 @@ import ( "encoding/json" "net/http" + "github.com/gorilla/mux" "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/federationsender/producers" + "github.com/matrix-org/dendrite/federationsender/queue" "github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/internal" @@ -23,6 +25,7 @@ type FederationSenderInternalAPI struct { producer *producers.RoomserverProducer federation *gomatrixserverlib.FederationClient keyRing *gomatrixserverlib.KeyRing + queues *queue.OutgoingQueues } func NewFederationSenderInternalAPI( @@ -31,6 +34,7 @@ func NewFederationSenderInternalAPI( federation *gomatrixserverlib.FederationClient, keyRing *gomatrixserverlib.KeyRing, statistics *types.Statistics, + queues *queue.OutgoingQueues, ) *FederationSenderInternalAPI { return &FederationSenderInternalAPI{ db: db, @@ -39,12 +43,13 @@ func NewFederationSenderInternalAPI( federation: federation, keyRing: keyRing, statistics: statistics, + queues: queues, } } // SetupHTTP adds the FederationSenderInternalAPI handlers to the http.ServeMux. -func (f *FederationSenderInternalAPI) SetupHTTP(servMux *http.ServeMux) { - servMux.Handle( +func (f *FederationSenderInternalAPI) SetupHTTP(internalAPIMux *mux.Router) { + internalAPIMux.Handle( api.FederationSenderQueryJoinedHostsInRoomPath, internal.MakeInternalAPI("QueryJoinedHostsInRoom", func(req *http.Request) util.JSONResponse { var request api.QueryJoinedHostsInRoomRequest @@ -58,7 +63,7 @@ func (f *FederationSenderInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.FederationSenderQueryJoinedHostServerNamesInRoomPath, internal.MakeInternalAPI("QueryJoinedHostServerNamesInRoom", func(req *http.Request) util.JSONResponse { var request api.QueryJoinedHostServerNamesInRoomRequest @@ -72,7 +77,7 @@ func (f *FederationSenderInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle(api.FederationSenderPerformJoinRequestPath, + internalAPIMux.Handle(api.FederationSenderPerformJoinRequestPath, internal.MakeInternalAPI("PerformJoinRequest", func(req *http.Request) util.JSONResponse { var request api.PerformJoinRequest var response api.PerformJoinResponse @@ -85,7 +90,7 @@ func (f *FederationSenderInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle(api.FederationSenderPerformLeaveRequestPath, + internalAPIMux.Handle(api.FederationSenderPerformLeaveRequestPath, internal.MakeInternalAPI("PerformLeaveRequest", func(req *http.Request) util.JSONResponse { var request api.PerformLeaveRequest var response api.PerformLeaveResponse @@ -98,4 +103,30 @@ func (f *FederationSenderInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(api.FederationSenderPerformDirectoryLookupRequestPath, + internal.MakeInternalAPI("PerformDirectoryLookupRequest", func(req *http.Request) util.JSONResponse { + var request api.PerformDirectoryLookupRequest + var response api.PerformDirectoryLookupResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := f.PerformDirectoryLookup(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(api.FederationSenderPerformServersAlivePath, + internal.MakeInternalAPI("PerformServersAliveRequest", func(req *http.Request) util.JSONResponse { + var request api.PerformServersAliveRequest + var response api.PerformServersAliveResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := f.PerformServersAlive(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/federationsender/internal/perform.go b/federationsender/internal/perform.go index 9791afefa..c601e9604 100644 --- a/federationsender/internal/perform.go +++ b/federationsender/internal/perform.go @@ -46,8 +46,19 @@ func (r *FederationSenderInternalAPI) PerformJoin( supportedVersions = append(supportedVersions, version) } - // Deduplicate the server names we were provided. - util.SortAndUnique(request.ServerNames) + // Deduplicate the server names we were provided but keep the ordering + // as this encodes useful information about which servers are most likely + // to respond. + seenSet := make(map[gomatrixserverlib.ServerName]bool) + var uniqueList []gomatrixserverlib.ServerName + for _, srv := range request.ServerNames { + if seenSet[srv] { + continue + } + seenSet[srv] = true + uniqueList = append(uniqueList, srv) + } + request.ServerNames = uniqueList // Try each server that we were provided until we land on one that // successfully completes the make-join send-join dance. @@ -264,3 +275,16 @@ func (r *FederationSenderInternalAPI) PerformLeave( request.RoomID, len(request.ServerNames), ) } + +// PerformServersAlive implements api.FederationSenderInternalAPI +func (r *FederationSenderInternalAPI) PerformServersAlive( + ctx context.Context, + request *api.PerformServersAliveRequest, + response *api.PerformServersAliveResponse, +) (err error) { + for _, srv := range request.Servers { + r.queues.RetryServer(srv) + } + + return nil +} diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index 45faa287c..4ab610de7 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -39,6 +39,7 @@ type destinationQueue struct { origin gomatrixserverlib.ServerName // origin of requests destination gomatrixserverlib.ServerName // destination of requests running atomic.Bool // is the queue worker running? + backingOff atomic.Bool // true if we're backing off statistics *types.ServerStatistics // statistics about this remote server incomingPDUs chan *gomatrixserverlib.HeaderedEvent // PDUs to send incomingEDUs chan *gomatrixserverlib.EDU // EDUs to send @@ -47,6 +48,28 @@ type destinationQueue struct { pendingPDUs []*gomatrixserverlib.HeaderedEvent // owned by backgroundSend pendingEDUs []*gomatrixserverlib.EDU // owned by backgroundSend pendingInvites []*gomatrixserverlib.InviteV2Request // owned by backgroundSend + retryServerCh chan bool // interrupts backoff +} + +// retry will clear the blacklist state and attempt to send built up events to the server, +// resetting and interrupting any backoff timers. +func (oq *destinationQueue) retry() { + // TODO: We don't send all events in the case where the server has been blacklisted as we + // drop events instead then. This means we will send the oldest N events (chan size, currently 128) + // and then skip ahead a lot which feels non-ideal but equally we can't persist thousands of events + // in-memory to maybe-send it one day. Ideally we would just shove these pending events in a database + // so we can send a lot of events. + oq.statistics.Success() + // if we were backing off, swap to not backing off and interrupt the select. + // We need to use an atomic bool here to prevent multiple calls to retry() blocking on the channel + // as it is unbuffered. + if oq.backingOff.CAS(true, false) { + oq.retryServerCh <- true + } + if !oq.running.Load() { + log.Infof("Restarting queue for %s", oq.destination) + go oq.backgroundSend() + } } // Send event adds the event to the pending queue for the destination. @@ -110,12 +133,26 @@ func (oq *destinationQueue) backgroundSend() { // of the queue and they will all be added to transactions // in order. oq.pendingPDUs = append(oq.pendingPDUs, pdu) + // If there are any more things waiting in the channel queue + // then read them. This is safe because we guarantee only + // having one goroutine per destination queue, so the channel + // isn't being consumed anywhere else. + for len(oq.incomingPDUs) > 0 { + oq.pendingPDUs = append(oq.pendingPDUs, <-oq.incomingPDUs) + } case edu := <-oq.incomingEDUs: // Likewise for EDUs, although we should probably not try // too hard with some EDUs (like typing notifications) after // a certain amount of time has passed. // TODO: think about EDU expiry some more oq.pendingEDUs = append(oq.pendingEDUs, edu) + // If there are any more things waiting in the channel queue + // then read them. This is safe because we guarantee only + // having one goroutine per destination queue, so the channel + // isn't being consumed anywhere else. + for len(oq.incomingEDUs) > 0 { + oq.pendingEDUs = append(oq.pendingEDUs, <-oq.incomingEDUs) + } case invite := <-oq.incomingInvites: // There's no strict ordering requirement for invites like // there is for transactions, so we put the invite onto the @@ -126,6 +163,13 @@ func (oq *destinationQueue) backgroundSend() { []*gomatrixserverlib.InviteV2Request{invite}, oq.pendingInvites..., ) + // If there are any more things waiting in the channel queue + // then read them. This is safe because we guarantee only + // having one goroutine per destination queue, so the channel + // isn't being consumed anywhere else. + for len(oq.incomingInvites) > 0 { + oq.pendingInvites = append(oq.pendingInvites, <-oq.incomingInvites) + } case <-time.After(time.Second * 30): // The worker is idle so stop the goroutine. It'll // get restarted automatically the next time we @@ -134,9 +178,15 @@ func (oq *destinationQueue) backgroundSend() { } // If we are backing off this server then wait for the - // backoff duration to complete first. + // backoff duration to complete first, or until explicitly + // told to retry. if backoff, duration := oq.statistics.BackoffDuration(); backoff { - <-time.After(duration) + oq.backingOff.Store(true) + select { + case <-time.After(duration): + case <-oq.retryServerCh: + } + oq.backingOff.Store(false) } // How many things do we have waiting? diff --git a/federationsender/queue/queue.go b/federationsender/queue/queue.go index aae6c53a0..386a3397f 100644 --- a/federationsender/queue/queue.go +++ b/federationsender/queue/queue.go @@ -52,6 +52,12 @@ func NewOutgoingQueues( } } +func (oqs *OutgoingQueues) getQueueIfExists(destination gomatrixserverlib.ServerName) *destinationQueue { + oqs.queuesMutex.Lock() + defer oqs.queuesMutex.Unlock() + return oqs.queues[destination] +} + func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue { oqs.queuesMutex.Lock() defer oqs.queuesMutex.Unlock() @@ -66,6 +72,7 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d incomingPDUs: make(chan *gomatrixserverlib.HeaderedEvent, 128), incomingEDUs: make(chan *gomatrixserverlib.EDU, 128), incomingInvites: make(chan *gomatrixserverlib.InviteV2Request, 128), + retryServerCh: make(chan bool), } oqs.queues[destination] = oq } @@ -160,6 +167,15 @@ func (oqs *OutgoingQueues) SendEDU( return nil } +// RetryServer attempts to resend events to the given server if we had given up. +func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { + q := oqs.getQueueIfExists(srv) + if q == nil { + return + } + q.retry() +} + // filterAndDedupeDests removes our own server from the list of destinations // and deduplicates any servers in the list that may appear more than once. func filterAndDedupeDests(origin gomatrixserverlib.ServerName, destinations []gomatrixserverlib.ServerName) ( diff --git a/federationsender/storage/storage_wasm.go b/federationsender/storage/storage_wasm.go index f593fd448..3d071bfef 100644 --- a/federationsender/storage/storage_wasm.go +++ b/federationsender/storage/storage_wasm.go @@ -33,7 +33,7 @@ func NewDatabase( } switch uri.Scheme { case "file": - return sqlite3.NewDatabase(dataSourceName) + return sqlite3.NewDatabase(uri.Path) case "postgres": return nil, fmt.Errorf("Cannot use postgres implementation") default: diff --git a/go.mod b/go.mod index bbac53eca..021cc02a4 100644 --- a/go.mod +++ b/go.mod @@ -16,9 +16,9 @@ require ( github.com/libp2p/go-libp2p-record v0.1.2 github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 - github.com/matrix-org/go-sqlite3-js v0.0.0-20200326102434-98eda28055bd + github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 - github.com/matrix-org/gomatrixserverlib v0.0.0-20200511154227-5cc71d36632b + github.com/matrix-org/gomatrixserverlib v0.0.0-20200602125825-24ff01093eca github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 github.com/mattn/go-sqlite3 v2.0.2+incompatible diff --git a/go.sum b/go.sum index d487f42a3..3f235d51e 100644 --- a/go.sum +++ b/go.sum @@ -352,12 +352,12 @@ github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5/go.mod h1:NgPCr+ github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg= github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 h1:eqE5OnGx9ZMWmrRbD3KF/3KtTunw0iQulI7YxOIdxo4= github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4/go.mod h1:3WluEZ9QXSwU30tWYqktnpC1x9mwZKx1r8uAv8Iq+a4= -github.com/matrix-org/go-sqlite3-js v0.0.0-20200326102434-98eda28055bd h1:C1FV4dRKF1uuGK8UH01+IoW6zZpfsTV1MvQimZvt418= -github.com/matrix-org/go-sqlite3-js v0.0.0-20200326102434-98eda28055bd/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= +github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 h1:Yb+Wlf/iHhWlLWd+kCgG+Fsg4Dc+xBl7hptfK7lD0zY= +github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bhrnp3Ky1qgx/fzCtCALOoGYylh2tpS9K4= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200511154227-5cc71d36632b h1:nAmSc1KvQOumoRTz/LD68KyrB6Q5/6q7CmQ5Bswc2nM= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200511154227-5cc71d36632b/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200602125825-24ff01093eca h1:s/dJePRDKjD1fGeoTnEYFqPmp1v7fC6GTd6iFwCKxw8= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200602125825-24ff01093eca/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= diff --git a/internal/basecomponent/base.go b/internal/basecomponent/base.go index 24c09621e..94e6b9ee6 100644 --- a/internal/basecomponent/base.go +++ b/internal/basecomponent/base.go @@ -22,11 +22,8 @@ import ( "net/url" "time" - "golang.org/x/crypto/ed25519" - "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/internal/keydb" - "github.com/matrix-org/dendrite/internal/keydb/cache" + "github.com/matrix-org/dendrite/internal/httpapis" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/naffka" @@ -43,6 +40,7 @@ import ( federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + serverKeyAPI "github.com/matrix-org/dendrite/serverkeyapi/api" "github.com/sirupsen/logrus" _ "net/http/pprof" @@ -57,8 +55,9 @@ type BaseDendrite struct { componentName string tracerCloser io.Closer - // APIMux should be used to register new public matrix api endpoints - APIMux *mux.Router + // PublicAPIMux should be used to register new public matrix api endpoints + PublicAPIMux *mux.Router + InternalAPIMux *mux.Router EnableHTTPAPIs bool httpClient *http.Client Cfg *config.Dendrite @@ -103,14 +102,17 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, enableHTTPAPIs Host: fmt.Sprintf("%s:%d", cfg.Proxy.Host, cfg.Proxy.Port), })} } - + + httpmux := mux.NewRouter() + return &BaseDendrite{ componentName: componentName, EnableHTTPAPIs: enableHTTPAPIs, tracerCloser: closer, Cfg: cfg, ImmutableCache: cache, - APIMux: mux.NewRouter().UseEncodedPath(), + PublicAPIMux: httpmux.PathPrefix(httpapis.PublicPathPrefix).Subrouter().UseEncodedPath(), + InternalAPIMux: httpmux.PathPrefix(httpapis.InternalPathPrefix).Subrouter().UseEncodedPath(), httpClient: &client, KafkaConsumer: kafkaConsumer, KafkaProducer: kafkaProducer, @@ -162,6 +164,20 @@ func (b *BaseDendrite) CreateHTTPFederationSenderAPIs() federationSenderAPI.Fede return f } +// CreateHTTPServerKeyAPIs returns ServerKeyInternalAPI for hitting the server key +// API over HTTP +func (b *BaseDendrite) CreateHTTPServerKeyAPIs() serverKeyAPI.ServerKeyInternalAPI { + f, err := serverKeyAPI.NewServerKeyInternalAPIHTTP( + b.Cfg.ServerKeyAPIURL(), + b.httpClient, + b.ImmutableCache, + ) + if err != nil { + logrus.WithError(err).Panic("NewServerKeyInternalAPIHTTP failed", b.httpClient) + } + return f +} + // CreateDeviceDB creates a new instance of the device database. Should only be // called once per component. func (b *BaseDendrite) CreateDeviceDB() devices.Database { @@ -184,27 +200,6 @@ func (b *BaseDendrite) CreateAccountsDB() accounts.Database { return db } -// CreateKeyDB creates a new instance of the key database. Should only be called -// once per component. -func (b *BaseDendrite) CreateKeyDB() keydb.Database { - db, err := keydb.NewDatabase( - string(b.Cfg.Database.ServerKey), - b.Cfg.DbProperties(), - b.Cfg.Matrix.ServerName, - b.Cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey), - b.Cfg.Matrix.KeyID, - ) - if err != nil { - logrus.WithError(err).Panicf("failed to connect to keys db") - } - - cachedDB, err := cache.NewKeyDatabase(db, b.ImmutableCache) - if err != nil { - logrus.WithError(err).Panicf("failed to create key cache wrapper") - } - return cachedDB -} - // CreateFederationClient creates a new federation client. Should only be called // once per component. func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient { @@ -230,7 +225,13 @@ func (b *BaseDendrite) SetupAndServeHTTP(bindaddr string, listenaddr string) { WriteTimeout: HTTPServerTimeout, } - internal.SetupHTTPAPI(http.DefaultServeMux, internal.WrapHandlerInCORS(b.APIMux), b.Cfg) + internal.SetupHTTPAPI( + http.DefaultServeMux, + b.PublicAPIMux, + b.InternalAPIMux, + b.Cfg, + b.EnableHTTPAPIs, + ) logrus.Infof("Starting %s server on %s", b.componentName, serv.Addr) err := serv.ListenAndServe() @@ -264,7 +265,15 @@ func setupNaffka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) { uri, err := url.Parse(string(cfg.Database.Naffka)) if err != nil || uri.Scheme == "file" { - db, err = sqlutil.Open(internal.SQLiteDriverName(), string(cfg.Database.Naffka), nil) + var cs string + if uri.Opaque != "" { // file:filename.db + cs = uri.Opaque + } else if uri.Path != "" { // file:///path/to/filename.db + cs = uri.Path + } else { + logrus.Panic("file uri has no filename") + } + db, err = sqlutil.Open(internal.SQLiteDriverName(), cs, nil) if err != nil { logrus.WithError(err).Panic("Failed to open naffka database") } diff --git a/internal/config/config.go b/internal/config/config.go index 32de41cc5..8e1789980 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -152,6 +152,8 @@ type Dendrite struct { OutputClientData Topic `yaml:"output_client_data"` // Topic for eduserver/api.OutputTypingEvent events. OutputTypingEvent Topic `yaml:"output_typing_event"` + // Topic for eduserver/api.OutputSendToDeviceEvent events. + OutputSendToDeviceEvent Topic `yaml:"output_send_to_device_event"` // Topic for user updates (profile, presence) UserUpdates Topic `yaml:"user_updates"` } @@ -223,6 +225,7 @@ type Dendrite struct { MediaAPI Address `yaml:"media_api"` ClientAPI Address `yaml:"client_api"` FederationAPI Address `yaml:"federation_api"` + ServerKeyAPI Address `yaml:"server_key_api"` AppServiceAPI Address `yaml:"appservice_api"` SyncAPI Address `yaml:"sync_api"` RoomServer Address `yaml:"room_server"` @@ -237,6 +240,7 @@ type Dendrite struct { MediaAPI Address `yaml:"media_api"` ClientAPI Address `yaml:"client_api"` FederationAPI Address `yaml:"federation_api"` + ServerKeyAPI Address `yaml:"server_key_api"` AppServiceAPI Address `yaml:"appservice_api"` SyncAPI Address `yaml:"sync_api"` RoomServer Address `yaml:"room_server"` @@ -626,6 +630,7 @@ func (config *Dendrite) checkListen(configErrs *configErrors) { checkNotEmpty(configErrs, "listen.sync_api", string(config.Listen.SyncAPI)) checkNotEmpty(configErrs, "listen.room_server", string(config.Listen.RoomServer)) checkNotEmpty(configErrs, "listen.edu_server", string(config.Listen.EDUServer)) + checkNotEmpty(configErrs, "listen.server_key_api", string(config.Listen.EDUServer)) } // checkLogging verifies the parameters logging.* are valid. @@ -757,6 +762,15 @@ func (config *Dendrite) FederationSenderURL() string { return "http://" + string(config.Listen.FederationSender) } +// FederationSenderURL returns an HTTP URL for where the federation sender is listening. +func (config *Dendrite) ServerKeyAPIURL() string { + // Hard code the server key API server to talk HTTP for now. + // If we support HTTPS we need to think of a practical way to do certificate validation. + // People setting up servers shouldn't need to get a certificate valid for the public + // internet for an internal API. + return "http://" + string(config.Listen.ServerKeyAPI) +} + // SetupTracing configures the opentracing using the supplied configuration. func (config *Dendrite) SetupTracing(serviceName string) (closer io.Closer, err error) { if !config.Tracing.Enabled { diff --git a/internal/http/http.go b/internal/http/http.go index 3c6475443..2b1891403 100644 --- a/internal/http/http.go +++ b/internal/http/http.go @@ -6,7 +6,10 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" + "strings" + "github.com/matrix-org/dendrite/internal/httpapis" opentracing "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" ) @@ -21,6 +24,14 @@ func PostJSON( return err } + parsedAPIURL, err := url.Parse(apiURL) + if err != nil { + return err + } + + parsedAPIURL.Path = httpapis.InternalPathPrefix + strings.TrimLeft(parsedAPIURL.Path, "/") + apiURL = parsedAPIURL.String() + req, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewReader(jsonBytes)) if err != nil { return err @@ -48,10 +59,10 @@ func PostJSON( var errorBody struct { Message string `json:"message"` } - if err = json.NewDecoder(res.Body).Decode(&errorBody); err != nil { - return err + if msgerr := json.NewDecoder(res.Body).Decode(&errorBody); msgerr == nil { + return fmt.Errorf("Internal API: %d from %s: %s", res.StatusCode, apiURL, errorBody.Message) } - return fmt.Errorf("api: %d: %s", res.StatusCode, errorBody.Message) + return fmt.Errorf("Internal API: %d from %s", res.StatusCode, apiURL) } return json.NewDecoder(res.Body).Decode(response) } diff --git a/internal/httpapi.go b/internal/httpapi.go index b4c53f58d..bacd1768a 100644 --- a/internal/httpapi.go +++ b/internal/httpapi.go @@ -1,17 +1,22 @@ package internal import ( + "context" "io" "net/http" "net/http/httptest" "net/http/httputil" "os" "strings" + "sync" "time" + "github.com/gorilla/mux" "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + federationsenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/httpapis" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" opentracing "github.com/opentracing/opentracing-go" @@ -168,6 +173,7 @@ func MakeFedAPI( metricsName string, serverName gomatrixserverlib.ServerName, keyRing gomatrixserverlib.KeyRing, + wakeup *FederationWakeups, f func(*http.Request, *gomatrixserverlib.FederationRequest) util.JSONResponse, ) http.Handler { h := func(req *http.Request) util.JSONResponse { @@ -177,18 +183,48 @@ func MakeFedAPI( if fedReq == nil { return errResp } + go wakeup.Wakeup(req.Context(), fedReq.Origin()) return f(req, fedReq) } return MakeExternalAPI(metricsName, h) } +type FederationWakeups struct { + FsAPI federationsenderAPI.FederationSenderInternalAPI + origins sync.Map +} + +func (f *FederationWakeups) Wakeup(ctx context.Context, origin gomatrixserverlib.ServerName) { + key, keyok := f.origins.Load(origin) + if keyok { + lastTime, ok := key.(time.Time) + if ok && time.Since(lastTime) < time.Minute { + return + } + } + aliveReq := federationsenderAPI.PerformServersAliveRequest{ + Servers: []gomatrixserverlib.ServerName{origin}, + } + aliveRes := federationsenderAPI.PerformServersAliveResponse{} + if err := f.FsAPI.PerformServersAlive(ctx, &aliveReq, &aliveRes); err != nil { + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "origin": origin, + }).Warn("incoming federation request failed to notify server alive") + } else { + f.origins.Store(origin, time.Now()) + } +} + // SetupHTTPAPI registers an HTTP API mux under /api and sets up a metrics // listener. -func SetupHTTPAPI(servMux *http.ServeMux, apiMux http.Handler, cfg *config.Dendrite) { +func SetupHTTPAPI(servMux *http.ServeMux, publicApiMux *mux.Router, internalApiMux *mux.Router, cfg *config.Dendrite, enableHTTPAPIs bool) { if cfg.Metrics.Enabled { servMux.Handle("/metrics", WrapHandlerInBasicAuth(promhttp.Handler(), cfg.Metrics.BasicAuth)) } - servMux.Handle("/api/", http.StripPrefix("/api", apiMux)) + if enableHTTPAPIs { + servMux.Handle(httpapis.InternalPathPrefix, internalApiMux) + } + servMux.Handle(httpapis.PublicPathPrefix, WrapHandlerInCORS(publicApiMux)) } // WrapHandlerInBasicAuth adds basic auth to a handler. Only used for /metrics diff --git a/internal/httpapis/paths.go b/internal/httpapis/paths.go new file mode 100644 index 000000000..8adec2dff --- /dev/null +++ b/internal/httpapis/paths.go @@ -0,0 +1,6 @@ +package httpapis + +const ( + PublicPathPrefix = "/_matrix/" + InternalPathPrefix = "/api/" +) diff --git a/internal/keydb/keyring.go b/internal/keydb/keyring.go deleted file mode 100644 index d0b1904ed..000000000 --- a/internal/keydb/keyring.go +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2017 New Vector Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package keydb - -import ( - "encoding/base64" - - "github.com/matrix-org/dendrite/internal/config" - "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" - "golang.org/x/crypto/ed25519" -) - -// CreateKeyRing creates and configures a KeyRing object. -// -// It creates the necessary key fetchers and collects them into a KeyRing -// backed by the given KeyDatabase. -func CreateKeyRing(client gomatrixserverlib.Client, - keyDB gomatrixserverlib.KeyDatabase, - cfg config.KeyPerspectives) gomatrixserverlib.KeyRing { - - fetchers := gomatrixserverlib.KeyRing{ - KeyFetchers: []gomatrixserverlib.KeyFetcher{ - &gomatrixserverlib.DirectKeyFetcher{ - Client: client, - }, - }, - KeyDatabase: keyDB, - } - - logrus.Info("Enabled direct key fetcher") - - var b64e = base64.StdEncoding.WithPadding(base64.NoPadding) - for _, ps := range cfg { - perspective := &gomatrixserverlib.PerspectiveKeyFetcher{ - PerspectiveServerName: ps.ServerName, - PerspectiveServerKeys: map[gomatrixserverlib.KeyID]ed25519.PublicKey{}, - Client: client, - } - - for _, key := range ps.Keys { - rawkey, err := b64e.DecodeString(key.PublicKey) - if err != nil { - logrus.WithError(err).WithFields(logrus.Fields{ - "server_name": ps.ServerName, - "public_key": key.PublicKey, - }).Warn("Couldn't parse perspective key") - continue - } - perspective.PerspectiveServerKeys[key.KeyID] = rawkey - } - - fetchers.KeyFetchers = append(fetchers.KeyFetchers, perspective) - - logrus.WithFields(logrus.Fields{ - "server_name": ps.ServerName, - "num_public_keys": len(ps.Keys), - }).Info("Enabled perspective key fetcher") - } - - return fetchers -} diff --git a/internal/sql.go b/internal/sql.go index d6a5a3086..e3c10afca 100644 --- a/internal/sql.go +++ b/internal/sql.go @@ -1,4 +1,6 @@ // Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,11 +18,17 @@ package internal import ( "database/sql" + "errors" "fmt" "runtime" "time" + + "go.uber.org/atomic" ) +// ErrUserExists is returned if a username already exists in the database. +var ErrUserExists = errors.New("Username already exists") + // A Transaction is something that can be committed or rolledback. type Transaction interface { // Commit the transaction @@ -107,3 +115,60 @@ type DbProperties interface { MaxOpenConns() int ConnMaxLifetime() time.Duration } + +// TransactionWriter allows queuing database writes so that you don't +// contend on database locks in, e.g. SQLite. Only one task will run +// at a time on a given TransactionWriter. +type TransactionWriter struct { + running atomic.Bool + todo chan transactionWriterTask +} + +func NewTransactionWriter() *TransactionWriter { + return &TransactionWriter{ + todo: make(chan transactionWriterTask), + } +} + +// transactionWriterTask represents a specific task. +type transactionWriterTask struct { + db *sql.DB + f func(txn *sql.Tx) error + wait chan error +} + +// Do queues a task to be run by a TransactionWriter. The function +// provided will be ran within a transaction as supplied by the +// database parameter. This will block until the task is finished. +func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx) error) error { + if w.todo == nil { + return errors.New("not initialised") + } + if !w.running.Load() { + go w.run() + } + task := transactionWriterTask{ + db: db, + f: f, + wait: make(chan error, 1), + } + w.todo <- task + return <-task.wait +} + +// run processes the tasks for a given transaction writer. Only one +// of these goroutines will run at a time. A transaction will be +// opened using the database object from the task and then this will +// be passed as a parameter to the task function. +func (w *TransactionWriter) run() { + if !w.running.CAS(false, true) { + return + } + defer w.running.Store(false) + for task := range w.todo { + task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error { + return task.f(txn) + }) + close(task.wait) + } +} diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 0f450165f..4e785cbc7 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -28,5 +28,5 @@ func SetupKeyServerComponent( deviceDB devices.Database, accountsDB accounts.Database, ) { - routing.Setup(base.APIMux, base.Cfg, accountsDB, deviceDB) + routing.Setup(base.PublicAPIMux, base.Cfg, accountsDB, deviceDB) } diff --git a/keyserver/routing/routing.go b/keyserver/routing/routing.go index 3ca68ade4..2d1122c62 100644 --- a/keyserver/routing/routing.go +++ b/keyserver/routing/routing.go @@ -27,7 +27,7 @@ import ( "github.com/matrix-org/util" ) -const pathPrefixR0 = "/_matrix/client/r0" +const pathPrefixR0 = "/client/r0" // Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client // to clients which need to make outbound HTTP requests. @@ -36,11 +36,11 @@ const pathPrefixR0 = "/_matrix/client/r0" // applied: // nolint: gocyclo func Setup( - apiMux *mux.Router, cfg *config.Dendrite, + publicAPIMux *mux.Router, cfg *config.Dendrite, accountDB accounts.Database, deviceDB devices.Database, ) { - r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() + r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter() authData := auth.Data{ AccountDB: accountDB, diff --git a/mediaapi/mediaapi.go b/mediaapi/mediaapi.go index 38572daf8..b5bec3907 100644 --- a/mediaapi/mediaapi.go +++ b/mediaapi/mediaapi.go @@ -35,6 +35,6 @@ func SetupMediaAPIComponent( } routing.Setup( - base.APIMux, base.Cfg, mediaDB, deviceDB, gomatrixserverlib.NewClient(), + base.PublicAPIMux, base.Cfg, mediaDB, deviceDB, gomatrixserverlib.NewClient(), ) } diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 4f765d173..11e4ecad5 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -33,7 +33,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" ) -const pathPrefixR0 = "/_matrix/media/r0" +const pathPrefixR0 = "/media/r0" // Setup registers the media API HTTP handlers // @@ -41,13 +41,13 @@ const pathPrefixR0 = "/_matrix/media/r0" // applied: // nolint: gocyclo func Setup( - apiMux *mux.Router, + publicAPIMux *mux.Router, cfg *config.Dendrite, db storage.Database, deviceDB devices.Database, client *gomatrixserverlib.Client, ) { - r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() + r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter() activeThumbnailGeneration := &types.ActiveThumbnailGeneration{ PathToResult: map[string]*types.ThumbnailGenerationResult{}, diff --git a/mediaapi/storage/storage_wasm.go b/mediaapi/storage/storage_wasm.go index 78de2cb82..aa188997f 100644 --- a/mediaapi/storage/storage_wasm.go +++ b/mediaapi/storage/storage_wasm.go @@ -35,7 +35,7 @@ func Open( case "postgres": return nil, fmt.Errorf("Cannot use postgres implementation") case "file": - return sqlite3.Open(dataSourceName) + return sqlite3.Open(uri.Path) default: return nil, fmt.Errorf("Cannot use postgres implementation") } diff --git a/publicroomsapi/directory/public_rooms.go b/publicroomsapi/directory/public_rooms.go index 7bd6740eb..df9df8ff9 100644 --- a/publicroomsapi/directory/public_rooms.go +++ b/publicroomsapi/directory/public_rooms.go @@ -16,7 +16,9 @@ package directory import ( "context" + "math/rand" "net/http" + "sort" "strconv" "sync" "time" @@ -96,6 +98,28 @@ func GetPostPublicRoomsWithExternal( // downcasting `limit` is safe as we know it isn't bigger than request.Limit which is int16 fedRooms := bulkFetchPublicRoomsFromServers(req.Context(), fedClient, extRoomsProvider.Homeservers(), int16(limit)) response.Chunk = append(response.Chunk, fedRooms...) + + // de-duplicate rooms with the same room ID. We can join the room via any of these aliases as we know these servers + // are alive and well, so we arbitrarily pick one (purposefully shuffling them to spread the load a bit) + var publicRooms []gomatrixserverlib.PublicRoom + haveRoomIDs := make(map[string]bool) + rand.Shuffle(len(response.Chunk), func(i, j int) { + response.Chunk[i], response.Chunk[j] = response.Chunk[j], response.Chunk[i] + }) + for _, r := range response.Chunk { + if haveRoomIDs[r.RoomID] { + continue + } + haveRoomIDs[r.RoomID] = true + publicRooms = append(publicRooms, r) + } + // sort by member count + sort.SliceStable(publicRooms, func(i, j int) bool { + return publicRooms[i].JoinedMembersCount > publicRooms[j].JoinedMembersCount + }) + + response.Chunk = publicRooms + return util.JSONResponse{ Code: http.StatusOK, JSON: response, diff --git a/publicroomsapi/publicroomsapi.go b/publicroomsapi/publicroomsapi.go index f767a5303..b53351ffd 100644 --- a/publicroomsapi/publicroomsapi.go +++ b/publicroomsapi/publicroomsapi.go @@ -43,5 +43,5 @@ func SetupPublicRoomsAPIComponent( logrus.WithError(err).Panic("failed to start public rooms server consumer") } - routing.Setup(base.APIMux, deviceDB, publicRoomsDB, rsAPI, fedClient, extRoomsProvider) + routing.Setup(base.PublicAPIMux, deviceDB, publicRoomsDB, rsAPI, fedClient, extRoomsProvider) } diff --git a/publicroomsapi/routing/routing.go b/publicroomsapi/routing/routing.go index 39a9dcca1..ee1434392 100644 --- a/publicroomsapi/routing/routing.go +++ b/publicroomsapi/routing/routing.go @@ -31,7 +31,7 @@ import ( "github.com/matrix-org/util" ) -const pathPrefixR0 = "/_matrix/client/r0" +const pathPrefixR0 = "/client/r0" // Setup configures the given mux with publicroomsapi server listeners // @@ -39,10 +39,10 @@ const pathPrefixR0 = "/_matrix/client/r0" // applied: // nolint: gocyclo func Setup( - apiMux *mux.Router, deviceDB devices.Database, publicRoomsDB storage.Database, rsAPI api.RoomserverInternalAPI, + publicAPIMux *mux.Router, deviceDB devices.Database, publicRoomsDB storage.Database, rsAPI api.RoomserverInternalAPI, fedClient *gomatrixserverlib.FederationClient, extRoomsProvider types.ExternalPublicRoomsProvider, ) { - r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() + r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter() authData := auth.Data{ AccountDB: nil, @@ -79,7 +79,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) // Federation - TODO: should this live here or in federation API? It's sure easier if it's here so here it is. - apiMux.Handle("/_matrix/federation/v1/publicRooms", + publicAPIMux.Handle("/federation/v1/publicRooms", internal.MakeExternalAPI("federation_public_rooms", func(req *http.Request) util.JSONResponse { return directory.GetPostPublicRooms(req, publicRoomsDB) }), diff --git a/publicroomsapi/storage/postgres/storage.go b/publicroomsapi/storage/postgres/storage.go index 691199a0b..5f0f551a7 100644 --- a/publicroomsapi/storage/postgres/storage.go +++ b/publicroomsapi/storage/postgres/storage.go @@ -30,20 +30,22 @@ import ( type PublicRoomsServerDatabase struct { db *sql.DB internal.PartitionOffsetStatements - statements publicRoomsStatements + statements publicRoomsStatements + localServerName gomatrixserverlib.ServerName } type attributeValue interface{} // NewPublicRoomsServerDatabase creates a new public rooms server database. -func NewPublicRoomsServerDatabase(dataSourceName string, dbProperties internal.DbProperties) (*PublicRoomsServerDatabase, error) { +func NewPublicRoomsServerDatabase(dataSourceName string, dbProperties internal.DbProperties, localServerName gomatrixserverlib.ServerName) (*PublicRoomsServerDatabase, error) { var db *sql.DB var err error if db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { return nil, err } storage := PublicRoomsServerDatabase{ - db: db, + db: db, + localServerName: localServerName, } if err = storage.PartitionOffsetStatements.Prepare(db, "publicroomsapi"); err != nil { return nil, err @@ -243,6 +245,9 @@ func (d *PublicRoomsServerDatabase) updateBooleanAttribute( func (d *PublicRoomsServerDatabase) updateRoomAliases( ctx context.Context, aliasesEvent gomatrixserverlib.Event, ) error { + if aliasesEvent.StateKey() == nil || *aliasesEvent.StateKey() != string(d.localServerName) { + return nil // only store our own aliases + } var content internal.AliasesContent if err := json.Unmarshal(aliasesEvent.Content(), &content); err != nil { return err diff --git a/publicroomsapi/storage/sqlite3/storage.go b/publicroomsapi/storage/sqlite3/storage.go index b81222dd3..898101127 100644 --- a/publicroomsapi/storage/sqlite3/storage.go +++ b/publicroomsapi/storage/sqlite3/storage.go @@ -32,20 +32,22 @@ import ( type PublicRoomsServerDatabase struct { db *sql.DB internal.PartitionOffsetStatements - statements publicRoomsStatements + statements publicRoomsStatements + localServerName gomatrixserverlib.ServerName } type attributeValue interface{} // NewPublicRoomsServerDatabase creates a new public rooms server database. -func NewPublicRoomsServerDatabase(dataSourceName string) (*PublicRoomsServerDatabase, error) { +func NewPublicRoomsServerDatabase(dataSourceName string, localServerName gomatrixserverlib.ServerName) (*PublicRoomsServerDatabase, error) { var db *sql.DB var err error if db, err = sqlutil.Open(internal.SQLiteDriverName(), dataSourceName, nil); err != nil { return nil, err } storage := PublicRoomsServerDatabase{ - db: db, + db: db, + localServerName: localServerName, } if err = storage.PartitionOffsetStatements.Prepare(db, "publicroomsapi"); err != nil { return nil, err @@ -245,6 +247,9 @@ func (d *PublicRoomsServerDatabase) updateBooleanAttribute( func (d *PublicRoomsServerDatabase) updateRoomAliases( ctx context.Context, aliasesEvent gomatrixserverlib.Event, ) error { + if aliasesEvent.StateKey() == nil || *aliasesEvent.StateKey() != string(d.localServerName) { + return nil // only store our own aliases + } var content internal.AliasesContent if err := json.Unmarshal(aliasesEvent.Content(), &content); err != nil { return err diff --git a/publicroomsapi/storage/storage.go b/publicroomsapi/storage/storage.go index 978d9a3c9..507b9fa6c 100644 --- a/publicroomsapi/storage/storage.go +++ b/publicroomsapi/storage/storage.go @@ -22,23 +22,24 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/publicroomsapi/storage/postgres" "github.com/matrix-org/dendrite/publicroomsapi/storage/sqlite3" + "github.com/matrix-org/gomatrixserverlib" ) const schemePostgres = "postgres" const schemeFile = "file" // NewPublicRoomsServerDatabase opens a database connection. -func NewPublicRoomsServerDatabase(dataSourceName string, dbProperties internal.DbProperties) (Database, error) { +func NewPublicRoomsServerDatabase(dataSourceName string, dbProperties internal.DbProperties, localServerName gomatrixserverlib.ServerName) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { - return postgres.NewPublicRoomsServerDatabase(dataSourceName, dbProperties) + return postgres.NewPublicRoomsServerDatabase(dataSourceName, dbProperties, localServerName) } switch uri.Scheme { case schemePostgres: - return postgres.NewPublicRoomsServerDatabase(dataSourceName, dbProperties) + return postgres.NewPublicRoomsServerDatabase(dataSourceName, dbProperties, localServerName) case schemeFile: - return sqlite3.NewPublicRoomsServerDatabase(dataSourceName) + return sqlite3.NewPublicRoomsServerDatabase(dataSourceName, localServerName) default: - return postgres.NewPublicRoomsServerDatabase(dataSourceName, dbProperties) + return postgres.NewPublicRoomsServerDatabase(dataSourceName, dbProperties, localServerName) } } diff --git a/publicroomsapi/storage/storage_wasm.go b/publicroomsapi/storage/storage_wasm.go index d00c339d8..fa3356c03 100644 --- a/publicroomsapi/storage/storage_wasm.go +++ b/publicroomsapi/storage/storage_wasm.go @@ -19,10 +19,11 @@ import ( "net/url" "github.com/matrix-org/dendrite/publicroomsapi/storage/sqlite3" + "github.com/matrix-org/gomatrixserverlib" ) // NewPublicRoomsServerDatabase opens a database connection. -func NewPublicRoomsServerDatabase(dataSourceName string) (Database, error) { +func NewPublicRoomsServerDatabase(dataSourceName string, localServerName gomatrixserverlib.ServerName) (Database, error) { uri, err := url.Parse(dataSourceName) if err != nil { return nil, err @@ -31,7 +32,7 @@ func NewPublicRoomsServerDatabase(dataSourceName string) (Database, error) { case "postgres": return nil, fmt.Errorf("Cannot use postgres implementation") case "file": - return sqlite3.NewPublicRoomsServerDatabase(dataSourceName) + return sqlite3.NewPublicRoomsServerDatabase(uri.Path, localServerName) default: return nil, fmt.Errorf("Cannot use postgres implementation") } diff --git a/roomserver/api/alias.go b/roomserver/api/alias.go index 93975d36c..54d2c633b 100644 --- a/roomserver/api/alias.go +++ b/roomserver/api/alias.go @@ -85,19 +85,19 @@ type RemoveRoomAliasRequest struct { type RemoveRoomAliasResponse struct{} // RoomserverSetRoomAliasPath is the HTTP path for the SetRoomAlias API. -const RoomserverSetRoomAliasPath = "/api/roomserver/setRoomAlias" +const RoomserverSetRoomAliasPath = "/roomserver/setRoomAlias" // RoomserverGetRoomIDForAliasPath is the HTTP path for the GetRoomIDForAlias API. -const RoomserverGetRoomIDForAliasPath = "/api/roomserver/GetRoomIDForAlias" +const RoomserverGetRoomIDForAliasPath = "/roomserver/GetRoomIDForAlias" // RoomserverGetAliasesForRoomIDPath is the HTTP path for the GetAliasesForRoomID API. -const RoomserverGetAliasesForRoomIDPath = "/api/roomserver/GetAliasesForRoomID" +const RoomserverGetAliasesForRoomIDPath = "/roomserver/GetAliasesForRoomID" // RoomserverGetCreatorIDForAliasPath is the HTTP path for the GetCreatorIDForAlias API. -const RoomserverGetCreatorIDForAliasPath = "/api/roomserver/GetCreatorIDForAlias" +const RoomserverGetCreatorIDForAliasPath = "/roomserver/GetCreatorIDForAlias" // RoomserverRemoveRoomAliasPath is the HTTP path for the RemoveRoomAlias API. -const RoomserverRemoveRoomAliasPath = "/api/roomserver/removeRoomAlias" +const RoomserverRemoveRoomAliasPath = "/roomserver/removeRoomAlias" // SetRoomAlias implements RoomserverAliasAPI func (h *httpRoomserverInternalAPI) SetRoomAlias( diff --git a/roomserver/api/input.go b/roomserver/api/input.go index f1eec75a8..d35ead764 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -103,7 +103,7 @@ type InputRoomEventsResponse struct { } // RoomserverInputRoomEventsPath is the HTTP path for the InputRoomEvents API. -const RoomserverInputRoomEventsPath = "/api/roomserver/inputRoomEvents" +const RoomserverInputRoomEventsPath = "/roomserver/inputRoomEvents" // InputRoomEvents implements RoomserverInputAPI func (h *httpRoomserverInternalAPI) InputRoomEvents( diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 93f36f451..f5afd67bf 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -10,10 +10,10 @@ import ( const ( // RoomserverPerformJoinPath is the HTTP path for the PerformJoin API. - RoomserverPerformJoinPath = "/api/roomserver/performJoin" + RoomserverPerformJoinPath = "/roomserver/performJoin" // RoomserverPerformLeavePath is the HTTP path for the PerformLeave API. - RoomserverPerformLeavePath = "/api/roomserver/performLeave" + RoomserverPerformLeavePath = "/roomserver/performLeave" ) type PerformJoinRequest struct { @@ -24,6 +24,7 @@ type PerformJoinRequest struct { } type PerformJoinResponse struct { + RoomID string `json:"room_id"` } func (h *httpRoomserverInternalAPI) PerformJoin( diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 98f46c771..916ecb362 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -273,40 +273,40 @@ type QueryRoomVersionForRoomResponse struct { } // RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API. -const RoomserverQueryLatestEventsAndStatePath = "/api/roomserver/queryLatestEventsAndState" +const RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" // RoomserverQueryStateAfterEventsPath is the HTTP path for the QueryStateAfterEvents API. -const RoomserverQueryStateAfterEventsPath = "/api/roomserver/queryStateAfterEvents" +const RoomserverQueryStateAfterEventsPath = "/roomserver/queryStateAfterEvents" // RoomserverQueryEventsByIDPath is the HTTP path for the QueryEventsByID API. -const RoomserverQueryEventsByIDPath = "/api/roomserver/queryEventsByID" +const RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID" // RoomserverQueryMembershipForUserPath is the HTTP path for the QueryMembershipForUser API. -const RoomserverQueryMembershipForUserPath = "/api/roomserver/queryMembershipForUser" +const RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser" // RoomserverQueryMembershipsForRoomPath is the HTTP path for the QueryMembershipsForRoom API -const RoomserverQueryMembershipsForRoomPath = "/api/roomserver/queryMembershipsForRoom" +const RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom" // RoomserverQueryInvitesForUserPath is the HTTP path for the QueryInvitesForUser API -const RoomserverQueryInvitesForUserPath = "/api/roomserver/queryInvitesForUser" +const RoomserverQueryInvitesForUserPath = "/roomserver/queryInvitesForUser" // RoomserverQueryServerAllowedToSeeEventPath is the HTTP path for the QueryServerAllowedToSeeEvent API -const RoomserverQueryServerAllowedToSeeEventPath = "/api/roomserver/queryServerAllowedToSeeEvent" +const RoomserverQueryServerAllowedToSeeEventPath = "/roomserver/queryServerAllowedToSeeEvent" // RoomserverQueryMissingEventsPath is the HTTP path for the QueryMissingEvents API -const RoomserverQueryMissingEventsPath = "/api/roomserver/queryMissingEvents" +const RoomserverQueryMissingEventsPath = "/roomserver/queryMissingEvents" // RoomserverQueryStateAndAuthChainPath is the HTTP path for the QueryStateAndAuthChain API -const RoomserverQueryStateAndAuthChainPath = "/api/roomserver/queryStateAndAuthChain" +const RoomserverQueryStateAndAuthChainPath = "/roomserver/queryStateAndAuthChain" // RoomserverQueryBackfillPath is the HTTP path for the QueryBackfillPath API -const RoomserverQueryBackfillPath = "/api/roomserver/queryBackfill" +const RoomserverQueryBackfillPath = "/roomserver/queryBackfill" // RoomserverQueryRoomVersionCapabilitiesPath is the HTTP path for the QueryRoomVersionCapabilities API -const RoomserverQueryRoomVersionCapabilitiesPath = "/api/roomserver/queryRoomVersionCapabilities" +const RoomserverQueryRoomVersionCapabilitiesPath = "/roomserver/queryRoomVersionCapabilities" // RoomserverQueryRoomVersionForRoomPath is the HTTP path for the QueryRoomVersionForRoom API -const RoomserverQueryRoomVersionForRoomPath = "/api/roomserver/queryRoomVersionForRoom" +const RoomserverQueryRoomVersionForRoomPath = "/roomserver/queryRoomVersionForRoom" // QueryLatestEventsAndState implements RoomserverQueryAPI func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState( diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 4d223cf36..248e457d1 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/Shopify/sarama" + "github.com/gorilla/mux" fsAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/caching" @@ -32,8 +33,8 @@ type RoomserverInternalAPI struct { // SetupHTTP adds the RoomserverInternalAPI handlers to the http.ServeMux. // nolint: gocyclo -func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { - servMux.Handle(api.RoomserverInputRoomEventsPath, +func (r *RoomserverInternalAPI) SetupHTTP(internalAPIMux *mux.Router) { + internalAPIMux.Handle(api.RoomserverInputRoomEventsPath, internal.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse { var request api.InputRoomEventsRequest var response api.InputRoomEventsResponse @@ -46,7 +47,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle(api.RoomserverPerformJoinPath, + internalAPIMux.Handle(api.RoomserverPerformJoinPath, internal.MakeInternalAPI("performJoin", func(req *http.Request) util.JSONResponse { var request api.PerformJoinRequest var response api.PerformJoinResponse @@ -59,7 +60,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle(api.RoomserverPerformLeavePath, + internalAPIMux.Handle(api.RoomserverPerformLeavePath, internal.MakeInternalAPI("performLeave", func(req *http.Request) util.JSONResponse { var request api.PerformLeaveRequest var response api.PerformLeaveResponse @@ -72,7 +73,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverQueryLatestEventsAndStatePath, internal.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse { var request api.QueryLatestEventsAndStateRequest @@ -86,7 +87,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverQueryStateAfterEventsPath, internal.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse { var request api.QueryStateAfterEventsRequest @@ -100,7 +101,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverQueryEventsByIDPath, internal.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse { var request api.QueryEventsByIDRequest @@ -114,7 +115,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverQueryMembershipForUserPath, internal.MakeInternalAPI("QueryMembershipForUser", func(req *http.Request) util.JSONResponse { var request api.QueryMembershipForUserRequest @@ -128,7 +129,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverQueryMembershipsForRoomPath, internal.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse { var request api.QueryMembershipsForRoomRequest @@ -142,7 +143,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverQueryInvitesForUserPath, internal.MakeInternalAPI("queryInvitesForUser", func(req *http.Request) util.JSONResponse { var request api.QueryInvitesForUserRequest @@ -156,7 +157,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverQueryServerAllowedToSeeEventPath, internal.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse { var request api.QueryServerAllowedToSeeEventRequest @@ -170,7 +171,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverQueryMissingEventsPath, internal.MakeInternalAPI("queryMissingEvents", func(req *http.Request) util.JSONResponse { var request api.QueryMissingEventsRequest @@ -184,7 +185,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverQueryStateAndAuthChainPath, internal.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse { var request api.QueryStateAndAuthChainRequest @@ -198,7 +199,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverQueryBackfillPath, internal.MakeInternalAPI("QueryBackfill", func(req *http.Request) util.JSONResponse { var request api.QueryBackfillRequest @@ -212,7 +213,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverQueryRoomVersionCapabilitiesPath, internal.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse { var request api.QueryRoomVersionCapabilitiesRequest @@ -226,7 +227,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverQueryRoomVersionForRoomPath, internal.MakeInternalAPI("QueryRoomVersionForRoom", func(req *http.Request) util.JSONResponse { var request api.QueryRoomVersionForRoomRequest @@ -240,7 +241,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverSetRoomAliasPath, internal.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse { var request api.SetRoomAliasRequest @@ -254,7 +255,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverGetRoomIDForAliasPath, internal.MakeInternalAPI("GetRoomIDForAlias", func(req *http.Request) util.JSONResponse { var request api.GetRoomIDForAliasRequest @@ -268,7 +269,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverGetCreatorIDForAliasPath, internal.MakeInternalAPI("GetCreatorIDForAlias", func(req *http.Request) util.JSONResponse { var request api.GetCreatorIDForAliasRequest @@ -282,7 +283,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverGetAliasesForRoomIDPath, internal.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse { var request api.GetAliasesForRoomIDRequest @@ -296,7 +297,7 @@ func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - servMux.Handle( + internalAPIMux.Handle( api.RoomserverRemoveRoomAliasPath, internal.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse { var request api.RemoveRoomAliasRequest diff --git a/roomserver/internal/perform_join.go b/roomserver/internal/perform_join.go index dc786b296..75dfdd19c 100644 --- a/roomserver/internal/perform_join.go +++ b/roomserver/internal/perform_join.go @@ -90,6 +90,12 @@ func (r *RoomserverInternalAPI) performJoinRoomByID( req *api.PerformJoinRequest, res *api.PerformJoinResponse, // nolint:unparam ) error { + // By this point, if req.RoomIDOrAlias contained an alias, then + // it will have been overwritten with a room ID by performJoinRoomByAlias. + // We should now include this in the response so that the CS API can + // return the right room ID. + res.RoomID = req.RoomIDOrAlias + // Get the domain part of the room ID. _, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias) if err != nil { @@ -121,20 +127,30 @@ func (r *RoomserverInternalAPI) performJoinRoomByID( return fmt.Errorf("eb.SetContent: %w", err) } - // First work out if this is in response to an existing invite. - // If it is then we avoid the situation where we might think we - // know about a room in the following section but don't know the - // latest state as all of our users have left. + // First work out if this is in response to an existing invite + // from a federated server. If it is then we avoid the situation + // where we might think we know about a room in the following + // section but don't know the latest state as all of our users + // have left. isInvitePending, inviteSender, err := r.isInvitePending(ctx, req.RoomIDOrAlias, req.UserID) if err == nil && isInvitePending { - // Add the server of the person who invited us to the server list, - // as they should be a fairly good bet. - if _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender); ierr == nil { - req.ServerNames = append(req.ServerNames, inviterDomain) + // Check if there's an invite pending. + _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) + if ierr != nil { + return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - // Perform a federated room join. - return r.performFederatedJoinRoomByID(ctx, req, res) + // Check that the domain isn't ours. If it's local then we don't + // need to do anything as our own copy of the room state will be + // up-to-date. + if inviterDomain != r.Cfg.Matrix.ServerName { + // Add the server of the person who invited us to the server list, + // as they should be a fairly good bet. + req.ServerNames = append(req.ServerNames, inviterDomain) + + // Perform a federated room join. + return r.performFederatedJoinRoomByID(ctx, req, res) + } } // Try to construct an actual join event from the template. diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 07250e7ce..82934d50d 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -15,8 +15,6 @@ package roomserver import ( - "net/http" - "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" @@ -51,9 +49,7 @@ func SetupRoomServerComponent( KeyRing: keyRing, } - if base.EnableHTTPAPIs { - internalAPI.SetupHTTP(http.DefaultServeMux) - } + internalAPI.SetupHTTP(base.InternalAPIMux) return &internalAPI } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 1e0232d20..52e6a96b7 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -63,29 +63,80 @@ type Database interface { SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) // Look up a room version from the room NID. GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) + // Stores a matrix room event in the database StoreEvent(ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) + // Look up the state entries for a list of string event IDs + // Returns an error if the there is an error talking to the database + // Returns a types.MissingEventError if the event IDs aren't in the database. StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + // Look up the string event state keys for a list of numeric event state keys + // Returns an error if there was a problem talking to the database. EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) + // Look up the numeric IDs for a list of events. + // Returns an error if there was a problem talking to the database. EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) + // Set the state at an event. FIXME TODO: "at" SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error + // Lookup the event IDs for a batch of event numeric IDs. + // Returns an error if the retrieval went wrong. EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + // Look up the latest events in a room in preparation for an update. + // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. + // Returns the latest events in the room and the last eventID sent to the log along with an updater. + // If this returns an error then no further action is required. GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) + // Look up event ID by transaction's info. + // This is used to determine if the room event is processed/processing already. + // Returns an empty string if no such event exists. GetTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (string, error) + // Look up the numeric ID for the room. + // Returns 0 if the room doesn't exists. + // Returns an error if there was a problem talking to the database. RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) // RoomNIDExcludingStubs is a special variation of RoomNID that will return 0 as if the room // does not exist if the room has no latest events. This can happen when we've received an // invite over federation for a room that we don't know anything else about yet. RoomNIDExcludingStubs(ctx context.Context, roomID string) (types.RoomNID, error) + // Look up event references for the latest events in the room and the current state snapshot. + // Returns the latest events, the current state and the maximum depth of the latest events plus 1. + // Returns an error if there was a problem talking to the database. LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) + // Look up the active invites targeting a user in a room and return the + // numeric state key IDs for the user IDs who sent them. + // Returns an error if there was a problem talking to the database. GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, err error) + // Save a given room alias with the room ID it refers to. + // Returns an error if there was a problem talking to the database. SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error + // Look up the room ID a given alias refers to. + // Returns an error if there was a problem talking to the database. GetRoomIDForAlias(ctx context.Context, alias string) (string, error) + // Look up all aliases referring to a given room ID. + // Returns an error if there was a problem talking to the database. GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) + // Get the user ID of the creator of an alias. + // Returns an error if there was a problem talking to the database. GetCreatorIDForAlias(ctx context.Context, alias string) (string, error) + // Remove a given room alias. + // Returns an error if there was a problem talking to the database. RemoveRoomAlias(ctx context.Context, alias string) error + // Build a membership updater for the target user in a room. MembershipUpdater(ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error) + // Lookup the membership of a given user in a given room. + // Returns the numeric ID of the latest membership event sent from this user + // in this room, along a boolean set to true if the user is still in this room, + // false if not. + // Returns an error if there was a problem talking to the database. GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) + // Lookup the membership event numeric IDs for all user that are or have + // been members of a given room. Only lookup events of "join" membership if + // joinOnly is set to true. + // Returns an error if there was a problem talking to the database. GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error) + // EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was + // not found. + // Returns an error if the retrieval went wrong. EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) + // Look up the room version for a given room. GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) } diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index 661c44721..3bce9f086 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -21,6 +21,8 @@ import ( "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -58,32 +60,28 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(eventJSONSchema) +func NewPostgresEventJSONTable(db *sql.DB) (tables.EventJSON, error) { + s := &eventJSONStatements{} + _, err := db.Exec(eventJSONSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventJSONStatements) insertEventJSON( - ctx context.Context, eventNID types.EventNID, eventJSON []byte, +func (s *eventJSONStatements) InsertEventJSON( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { _, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON) return err } -type eventJSONPair struct { - EventNID types.EventNID - EventJSON []byte -} - -func (s *eventJSONStatements) bulkSelectEventJSON( +func (s *eventJSONStatements) BulkSelectEventJSON( ctx context.Context, eventNIDs []types.EventNID, -) ([]eventJSONPair, error) { +) ([]tables.EventJSONPair, error) { rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err @@ -94,7 +92,7 @@ func (s *eventJSONStatements) bulkSelectEventJSON( // because of the unique constraint on event NIDs. // So we can allocate an array of the correct size now. // We might get fewer results than NIDs so we adjust the length of the slice before returning it. - results := make([]eventJSONPair, len(eventNIDs)) + results := make([]tables.EventJSONPair, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { result := &results[i] diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index b213e057b..b7e2cb6b7 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -21,6 +21,8 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -74,20 +76,21 @@ type eventStateKeyStatements struct { bulkSelectEventStateKeyStmt *sql.Stmt } -func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(eventStateKeysSchema) +func NewPostgresEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { + s := &eventStateKeyStatements{} + _, err := db.Exec(eventStateKeysSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventStateKeyStatements) insertEventStateKeyNID( +func (s *eventStateKeyStatements) InsertEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 @@ -96,7 +99,7 @@ func (s *eventStateKeyStatements) insertEventStateKeyNID( return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *eventStateKeyStatements) selectEventStateKeyNID( +func (s *eventStateKeyStatements) SelectEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 @@ -105,7 +108,7 @@ func (s *eventStateKeyStatements) selectEventStateKeyNID( return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( +func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( ctx context.Context, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext( @@ -128,7 +131,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( return result, rows.Err() } -func (s *eventStateKeyStatements) bulkSelectEventStateKey( +func (s *eventStateKeyStatements) BulkSelectEventStateKey( ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index 2b0910e71..f8a7bff1b 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -22,6 +22,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -98,36 +100,37 @@ type eventTypeStatements struct { bulkSelectEventTypeNIDStmt *sql.Stmt } -func (s *eventTypeStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(eventTypesSchema) +func NewPostgresEventTypesTable(db *sql.DB) (tables.EventTypes, error) { + s := &eventTypeStatements{} + _, err := db.Exec(eventTypesSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventTypeStatements) insertEventTypeNID( - ctx context.Context, eventType string, +func (s *eventTypeStatements) InsertEventTypeNID( + ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 - err := s.insertEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) + err := txn.Stmt(s.insertEventTypeNIDStmt).QueryRowContext(ctx, eventType).Scan(&eventTypeNID) return types.EventTypeNID(eventTypeNID), err } -func (s *eventTypeStatements) selectEventTypeNID( - ctx context.Context, eventType string, +func (s *eventTypeStatements) SelectEventTypeNID( + ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 - err := s.selectEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) + err := txn.Stmt(s.selectEventTypeNIDStmt).QueryRowContext(ctx, eventType).Scan(&eventTypeNID) return types.EventTypeNID(eventTypeNID), err } -func (s *eventTypeStatements) bulkSelectEventTypeNID( +func (s *eventTypeStatements) BulkSelectEventTypeNID( ctx context.Context, eventTypes []string, ) (map[string]types.EventTypeNID, error) { rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes)) diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index c28fa8e6b..d39e9511a 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -22,6 +22,8 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -136,13 +138,14 @@ type eventStatements struct { selectRoomNIDForEventNIDStmt *sql.Stmt } -func (s *eventStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(eventsSchema) +func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { + s := &eventStatements{} + _, err := db.Exec(eventsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, @@ -157,11 +160,12 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, {&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventStatements) insertEvent( +func (s *eventStatements) InsertEvent( ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, @@ -179,8 +183,8 @@ func (s *eventStatements) insertEvent( return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } -func (s *eventStatements) selectEvent( - ctx context.Context, eventID string, +func (s *eventStatements) SelectEvent( + ctx context.Context, txn *sql.Tx, eventID string, ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 @@ -190,7 +194,7 @@ func (s *eventStatements) selectEvent( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError -func (s *eventStatements) bulkSelectStateEventByID( +func (s *eventStatements) BulkSelectStateEventByID( ctx context.Context, eventIDs []string, ) ([]types.StateEntry, error) { rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) @@ -233,7 +237,7 @@ func (s *eventStatements) bulkSelectStateEventByID( // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. -func (s *eventStatements) bulkSelectStateAtEventByID( +func (s *eventStatements) BulkSelectStateAtEventByID( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) @@ -270,14 +274,14 @@ func (s *eventStatements) bulkSelectStateAtEventByID( return results, nil } -func (s *eventStatements) updateEventState( +func (s *eventStatements) UpdateEventState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { _, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID)) return err } -func (s *eventStatements) selectEventSentToOutput( +func (s *eventStatements) SelectEventSentToOutput( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (sentToOutput bool, err error) { stmt := internal.TxStmt(txn, s.selectEventSentToOutputStmt) @@ -285,13 +289,13 @@ func (s *eventStatements) selectEventSentToOutput( return } -func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { +func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { stmt := internal.TxStmt(txn, s.updateEventSentToOutputStmt) _, err := stmt.ExecContext(ctx, int64(eventNID)) return err } -func (s *eventStatements) selectEventID( +func (s *eventStatements) SelectEventID( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (eventID string, err error) { stmt := internal.TxStmt(txn, s.selectEventIDStmt) @@ -299,7 +303,7 @@ func (s *eventStatements) selectEventID( return } -func (s *eventStatements) bulkSelectStateAtEventAndReference( +func (s *eventStatements) BulkSelectStateAtEventAndReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]types.StateAtEventAndReference, error) { stmt := internal.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt) @@ -341,8 +345,8 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference( return results, nil } -func (s *eventStatements) bulkSelectEventReference( - ctx context.Context, eventNIDs []types.EventNID, +func (s *eventStatements) BulkSelectEventReference( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]gomatrixserverlib.EventReference, error) { rows, err := s.bulkSelectEventReferenceStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { @@ -367,7 +371,7 @@ func (s *eventStatements) bulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { +func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err @@ -394,7 +398,7 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []typ // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err @@ -412,7 +416,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []str return results, rows.Err() } -func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) { +func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { var result int64 stmt := s.selectMaxEventDepthStmt err := stmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result) @@ -422,11 +426,10 @@ func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []t return result, nil } -func (s *eventStatements) selectRoomNIDForEventNID( - ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +func (s *eventStatements) SelectRoomNIDForEventNID( + ctx context.Context, eventNID types.EventNID, ) (roomNID types.RoomNID, err error) { - selectStmt := internal.TxStmt(txn, s.selectRoomNIDForEventNIDStmt) - err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) + err = s.selectRoomNIDForEventNIDStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) return } diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go index f0fb919e6..55a5bc7d7 100644 --- a/roomserver/storage/postgres/invite_table.go +++ b/roomserver/storage/postgres/invite_table.go @@ -20,6 +20,8 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -79,20 +81,21 @@ type inviteStatements struct { updateInviteRetiredStmt *sql.Stmt } -func (s *inviteStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(inviteSchema) +func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { + s := &inviteStatements{} + _, err := db.Exec(inviteSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *inviteStatements) insertInviteEvent( +func (s *inviteStatements) InsertInviteEvent( ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, @@ -111,7 +114,7 @@ func (s *inviteStatements) insertInviteEvent( return count != 0, nil } -func (s *inviteStatements) updateInviteRetired( +func (s *inviteStatements) UpdateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) ([]string, error) { @@ -133,8 +136,8 @@ func (s *inviteStatements) updateInviteRetired( return eventIDs, rows.Err() } -// selectInviteActiveForUserInRoom returns a list of sender state key NIDs -func (s *inviteStatements) selectInviteActiveForUserInRoom( +// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs +func (s *inviteStatements) SelectInviteActiveForUserInRoom( ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, error) { diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index f290a05f9..c635ce28a 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -20,17 +20,11 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) -type membershipState int64 - -const ( - membershipStateLeaveOrBan membershipState = 1 - membershipStateInvite membershipState = 2 - membershipStateJoin membershipState = 3 -) - const membershipSchema = ` -- The membership table is used to coordinate updates between the invite table -- and the room state tables. @@ -115,13 +109,14 @@ type membershipStatements struct { updateMembershipStmt *sql.Stmt } -func (s *membershipStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(membershipSchema) +func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { + s := &membershipStatements{} + _, err := db.Exec(membershipSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertMembershipStmt, insertMembershipSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, @@ -130,10 +125,10 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *membershipStatements) insertMembership( +func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, @@ -143,27 +138,27 @@ func (s *membershipStatements) insertMembership( return err } -func (s *membershipStatements) selectMembershipForUpdate( +func (s *membershipStatements) SelectMembershipForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (membership membershipState, err error) { +) (membership tables.MembershipState, err error) { err = internal.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership) return } -func (s *membershipStatements) selectMembershipFromRoomAndTarget( +func (s *membershipStatements) SelectMembershipFromRoomAndTarget( ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (eventNID types.EventNID, membership membershipState, err error) { +) (eventNID types.EventNID, membership tables.MembershipState, err error) { err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID) return } -func (s *membershipStatements) selectMembershipsFromRoom( +func (s *membershipStatements) SelectMembershipsFromRoom( ctx context.Context, roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var stmt *sql.Stmt @@ -188,9 +183,9 @@ func (s *membershipStatements) selectMembershipsFromRoom( return eventNIDs, rows.Err() } -func (s *membershipStatements) selectMembershipsFromRoomAndMembership( +func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( ctx context.Context, - roomNID types.RoomNID, membership membershipState, localOnly bool, + roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var rows *sql.Rows var stmt *sql.Stmt @@ -215,10 +210,10 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership( return eventNIDs, rows.Err() } -func (s *membershipStatements) updateMembership( +func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - senderUserNID types.EventStateKeyNID, membership membershipState, + senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, ) error { _, err := internal.TxStmt(txn, s.updateMembershipStmt).ExecContext( diff --git a/roomserver/storage/postgres/prepare.go b/roomserver/storage/postgres/prepare.go deleted file mode 100644 index 70b6e5161..000000000 --- a/roomserver/storage/postgres/prepare.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "database/sql" -) - -// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. -type statementList []struct { - statement **sql.Stmt - sql string -} - -// prepare the SQL for each statement in the list and assign the result to the prepared statement. -func (s statementList) prepare(db *sql.DB) (err error) { - for _, statement := range s { - if *statement.statement, err = db.Prepare(statement.sql); err != nil { - return - } - } - return -} diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go index e3ad5dc80..ff0306967 100644 --- a/roomserver/storage/postgres/previous_events_table.go +++ b/roomserver/storage/postgres/previous_events_table.go @@ -20,6 +20,8 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -63,19 +65,20 @@ type previousEventStatements struct { selectPreviousEventExistsStmt *sql.Stmt } -func (s *previousEventStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(previousEventSchema) +func NewPostgresPreviousEventsTable(db *sql.DB) (tables.PreviousEvents, error) { + s := &previousEventStatements{} + _, err := db.Exec(previousEventSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *previousEventStatements) insertPreviousEvent( +func (s *previousEventStatements) InsertPreviousEvent( ctx context.Context, txn *sql.Tx, previousEventID string, @@ -91,7 +94,7 @@ func (s *previousEventStatements) insertPreviousEvent( // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. -func (s *previousEventStatements) selectPreviousEventExists( +func (s *previousEventStatements) SelectPreviousEventExists( ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, ) error { var ok int64 diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go index c77edd0e0..85042c54f 100644 --- a/roomserver/storage/postgres/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -20,6 +20,8 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" ) const roomAliasesSchema = ` @@ -59,28 +61,29 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(roomAliasesSchema) +func NewPostgresRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { + s := &roomAliasesStatements{} + _, err := db.Exec(roomAliasesSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL}, {&s.deleteRoomAliasStmt, deleteRoomAliasSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *roomAliasesStatements) insertRoomAlias( +func (s *roomAliasesStatements) InsertRoomAlias( ctx context.Context, alias string, roomID string, creatorUserID string, ) (err error) { _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) return } -func (s *roomAliasesStatements) selectRoomIDFromAlias( +func (s *roomAliasesStatements) SelectRoomIDFromAlias( ctx context.Context, alias string, ) (roomID string, err error) { err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) @@ -90,7 +93,7 @@ func (s *roomAliasesStatements) selectRoomIDFromAlias( return } -func (s *roomAliasesStatements) selectAliasesFromRoomID( +func (s *roomAliasesStatements) SelectAliasesFromRoomID( ctx context.Context, roomID string, ) ([]string, error) { rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) @@ -111,7 +114,7 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID( return aliases, rows.Err() } -func (s *roomAliasesStatements) selectCreatorIDFromAlias( +func (s *roomAliasesStatements) SelectCreatorIDFromAlias( ctx context.Context, alias string, ) (creatorID string, err error) { err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) @@ -121,7 +124,7 @@ func (s *roomAliasesStatements) selectCreatorIDFromAlias( return } -func (s *roomAliasesStatements) deleteRoomAlias( +func (s *roomAliasesStatements) DeleteRoomAlias( ctx context.Context, alias string, ) (err error) { _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias) diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index fc64489de..5a353ed1c 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -22,6 +22,8 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -82,12 +84,13 @@ type roomStatements struct { selectRoomVersionForRoomNIDStmt *sql.Stmt } -func (s *roomStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(roomsSchema) +func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { + s := &roomStatements{} + _, err := db.Exec(roomsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, @@ -95,10 +98,10 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *roomStatements) insertRoomNID( +func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (types.RoomNID, error) { @@ -108,7 +111,7 @@ func (s *roomStatements) insertRoomNID( return types.RoomNID(roomNID), err } -func (s *roomStatements) selectRoomNID( +func (s *roomStatements) SelectRoomNID( ctx context.Context, txn *sql.Tx, roomID string, ) (types.RoomNID, error) { var roomNID int64 @@ -117,8 +120,8 @@ func (s *roomStatements) selectRoomNID( return types.RoomNID(roomNID), err } -func (s *roomStatements) selectLatestEventNIDs( - ctx context.Context, roomNID types.RoomNID, +func (s *roomStatements) SelectLatestEventNIDs( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { var nids pq.Int64Array var stateSnapshotNID int64 @@ -134,7 +137,7 @@ func (s *roomStatements) selectLatestEventNIDs( return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil } -func (s *roomStatements) selectLatestEventsNIDsForUpdate( +func (s *roomStatements) SelectLatestEventsNIDsForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) { var nids pq.Int64Array @@ -152,7 +155,7 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate( return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil } -func (s *roomStatements) updateLatestEventNIDs( +func (s *roomStatements) UpdateLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, @@ -171,7 +174,7 @@ func (s *roomStatements) updateLatestEventNIDs( return err } -func (s *roomStatements) selectRoomVersionForRoomID( +func (s *roomStatements) SelectRoomVersionForRoomID( ctx context.Context, txn *sql.Tx, roomID string, ) (gomatrixserverlib.RoomVersion, error) { var roomVersion gomatrixserverlib.RoomVersion @@ -183,12 +186,11 @@ func (s *roomStatements) selectRoomVersionForRoomID( return roomVersion, err } -func (s *roomStatements) selectRoomVersionForRoomNID( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +func (s *roomStatements) SelectRoomVersionForRoomNID( + ctx context.Context, roomNID types.RoomNID, ) (gomatrixserverlib.RoomVersion, error) { var roomVersion gomatrixserverlib.RoomVersion - stmt := internal.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt) - err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) + err := s.selectRoomVersionForRoomNIDStmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) if err == sql.ErrNoRows { return roomVersion, errors.New("room not found") } diff --git a/roomserver/storage/postgres/sql.go b/roomserver/storage/postgres/sql.go deleted file mode 100644 index 5956886ce..000000000 --- a/roomserver/storage/postgres/sql.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "database/sql" -) - -type statements struct { - eventTypeStatements - eventStateKeyStatements - roomStatements - eventStatements - eventJSONStatements - stateSnapshotStatements - stateBlockStatements - previousEventStatements - roomAliasesStatements - inviteStatements - membershipStatements - transactionStatements -} - -func (s *statements) prepare(db *sql.DB) error { - var err error - - for _, prepare := range []func(db *sql.DB) error{ - s.eventTypeStatements.prepare, - s.eventStateKeyStatements.prepare, - s.roomStatements.prepare, - s.eventStatements.prepare, - s.eventJSONStatements.prepare, - s.stateSnapshotStatements.prepare, - s.stateBlockStatements.prepare, - s.previousEventStatements.prepare, - s.roomAliasesStatements.prepare, - s.inviteStatements.prepare, - s.membershipStatements.prepare, - s.transactionStatements.prepare, - } { - if err = prepare(db); err != nil { - return err - } - } - - return nil -} diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index 38334fa9e..dbe21d989 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -24,6 +24,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" ) @@ -87,25 +89,30 @@ type stateBlockStatements struct { bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt } -func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(stateDataSchema) +func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) { + s := &stateBlockStatements{} + _, err := db.Exec(stateDataSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *stateBlockStatements) bulkInsertStateData( +func (s *stateBlockStatements) BulkInsertStateData( ctx context.Context, - stateBlockNID types.StateBlockNID, + txn *sql.Tx, entries []types.StateEntry, -) error { +) (types.StateBlockNID, error) { + stateBlockNID, err := s.selectNextStateBlockNID(ctx) + if err != nil { + return 0, err + } for _, entry := range entries { _, err := s.insertStateDataStmt.ExecContext( ctx, @@ -115,10 +122,10 @@ func (s *stateBlockStatements) bulkInsertStateData( int64(entry.EventNID), ) if err != nil { - return err + return 0, err } } - return nil + return stateBlockNID, nil } func (s *stateBlockStatements) selectNextStateBlockNID( @@ -129,7 +136,7 @@ func (s *stateBlockStatements) selectNextStateBlockNID( return types.StateBlockNID(stateBlockNID), err } -func (s *stateBlockStatements) bulkSelectStateBlockEntries( +func (s *stateBlockStatements) BulkSelectStateBlockEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { nids := make([]int64, len(stateBlockNIDs)) @@ -180,7 +187,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries( return results, err } -func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( +func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index a1f26e228..0f8f1c51e 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -21,6 +21,8 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -64,30 +66,31 @@ type stateSnapshotStatements struct { bulkSelectStateBlockNIDsStmt *sql.Stmt } -func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(stateSnapshotSchema) +func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { + s := &stateSnapshotStatements{} + _, err := db.Exec(stateSnapshotSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *stateSnapshotStatements) insertState( - ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, +func (s *stateSnapshotStatements) InsertState( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, ) (stateNID types.StateSnapshotNID, err error) { nids := make([]int64, len(stateBlockNIDs)) for i := range stateBlockNIDs { nids[i] = int64(stateBlockNIDs[i]) } - err = s.insertStateStmt.QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) + err = txn.Stmt(s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) return } -func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( +func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { nids := make([]int64, len(stateNIDs)) diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 4d1d603e3..971f2b9e6 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -9,799 +9,99 @@ // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implie // See the License for the specific language governing permissions and // limitations under the License. package postgres import ( - "context" "database/sql" - "encoding/json" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" // Import the postgres database driver. _ "github.com/lib/pq" - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/roomserver/storage/shared" ) // A Database is used to store room events and stream offsets. type Database struct { - statements statements - db *sql.DB + shared.Database } // Open a postgres database. +// nolint: gocyclo func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, error) { var d Database + var db *sql.DB var err error - if d.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { + if db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { return nil, err } - if err = d.statements.prepare(d.db); err != nil { + eventStateKeys, err := NewPostgresEventStateKeysTable(db) + if err != nil { return nil, err } + eventTypes, err := NewPostgresEventTypesTable(db) + if err != nil { + return nil, err + } + eventJSON, err := NewPostgresEventJSONTable(db) + if err != nil { + return nil, err + } + events, err := NewPostgresEventsTable(db) + if err != nil { + return nil, err + } + rooms, err := NewPostgresRoomsTable(db) + if err != nil { + return nil, err + } + transactions, err := NewPostgresTransactionsTable(db) + if err != nil { + return nil, err + } + stateBlock, err := NewPostgresStateBlockTable(db) + if err != nil { + return nil, err + } + stateSnapshot, err := NewPostgresStateSnapshotTable(db) + if err != nil { + return nil, err + } + roomAliases, err := NewPostgresRoomAliasesTable(db) + if err != nil { + return nil, err + } + prevEvents, err := NewPostgresPreviousEventsTable(db) + if err != nil { + return nil, err + } + invites, err := NewPostgresInvitesTable(db) + if err != nil { + return nil, err + } + membership, err := NewPostgresMembershipTable(db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: db, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + EventJSONTable: eventJSON, + EventsTable: events, + RoomsTable: rooms, + TransactionsTable: transactions, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + PrevEventsTable: prevEvents, + RoomAliasesTable: roomAliases, + InvitesTable: invites, + MembershipTable: membership, + } return &d, nil } - -// StoreEvent implements input.EventDatabase -func (d *Database) StoreEvent( - ctx context.Context, event gomatrixserverlib.Event, - txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, -) (types.RoomNID, types.StateAtEvent, error) { - var ( - roomNID types.RoomNID - eventTypeNID types.EventTypeNID - eventStateKeyNID types.EventStateKeyNID - eventNID types.EventNID - stateNID types.StateSnapshotNID - err error - ) - - if txnAndSessionID != nil { - if err = d.statements.insertTransaction( - ctx, txnAndSessionID.TransactionID, - txnAndSessionID.SessionID, event.Sender(), event.EventID(), - ); err != nil { - return 0, types.StateAtEvent{}, err - } - } - - // TODO: Here we should aim to have two different code paths for new rooms - // vs existing ones. - - // Get the default room version. If the client doesn't supply a room_version - // then we will use our configured default to create the room. - // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom - // Note that the below logic depends on the m.room.create event being the - // first event that is persisted to the database when creating or joining a - // room. - var roomVersion gomatrixserverlib.RoomVersion - if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { - return 0, types.StateAtEvent{}, err - } - - if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID(), roomVersion); err != nil { - return 0, types.StateAtEvent{}, err - } - - if eventTypeNID, err = d.assignEventTypeNID(ctx, event.Type()); err != nil { - return 0, types.StateAtEvent{}, err - } - - eventStateKey := event.StateKey() - // Assigned a numeric ID for the state_key if there is one present. - // Otherwise set the numeric ID for the state_key to 0. - if eventStateKey != nil { - if eventStateKeyNID, err = d.assignStateKeyNID(ctx, nil, *eventStateKey); err != nil { - return 0, types.StateAtEvent{}, err - } - } - - if eventNID, stateNID, err = d.statements.insertEvent( - ctx, - roomNID, - eventTypeNID, - eventStateKeyNID, - event.EventID(), - event.EventReference().EventSHA256, - authEventNIDs, - event.Depth(), - ); err != nil { - if err == sql.ErrNoRows { - // We've already inserted the event so select the numeric event ID - eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID()) - } - if err != nil { - return 0, types.StateAtEvent{}, err - } - } - - if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil { - return 0, types.StateAtEvent{}, err - } - - return roomNID, types.StateAtEvent{ - BeforeStateSnapshotNID: stateNID, - StateEntry: types.StateEntry{ - StateKeyTuple: types.StateKeyTuple{ - EventTypeNID: eventTypeNID, - EventStateKeyNID: eventStateKeyNID, - }, - EventNID: eventNID, - }, - }, nil -} - -func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) ( - gomatrixserverlib.RoomVersion, error, -) { - var err error - var roomVersion gomatrixserverlib.RoomVersion - // Look for m.room.create events. - if event.Type() != gomatrixserverlib.MRoomCreate { - return gomatrixserverlib.RoomVersion(""), nil - } - roomVersion = gomatrixserverlib.RoomVersionV1 - var createContent gomatrixserverlib.CreateContent - // The m.room.create event contains an optional "room_version" key in - // the event content, so we need to unmarshal that first. - if err = json.Unmarshal(event.Content(), &createContent); err != nil { - return gomatrixserverlib.RoomVersion(""), err - } - // A room version was specified in the event content? - if createContent.RoomVersion != nil { - roomVersion = gomatrixserverlib.RoomVersion(*createContent.RoomVersion) - } - return roomVersion, err -} - -func (d *Database) assignRoomNID( - ctx context.Context, txn *sql.Tx, - roomID string, roomVersion gomatrixserverlib.RoomVersion, -) (types.RoomNID, error) { - // Check if we already have a numeric ID in the database. - roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID, roomVersion) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) - } - } - return roomNID, err -} - -func (d *Database) assignEventTypeNID( - ctx context.Context, eventType string, -) (types.EventTypeNID, error) { - // Check if we already have a numeric ID in the database. - eventTypeNID, err := d.statements.selectEventTypeNID(ctx, eventType) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - eventTypeNID, err = d.statements.insertEventTypeNID(ctx, eventType) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - eventTypeNID, err = d.statements.selectEventTypeNID(ctx, eventType) - } - } - return eventTypeNID, err -} - -func (d *Database) assignStateKeyNID( - ctx context.Context, txn *sql.Tx, eventStateKey string, -) (types.EventStateKeyNID, error) { - // Check if we already have a numeric ID in the database. - eventStateKeyNID, err := d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) - } - } - return eventStateKeyNID, err -} - -// StateEntriesForEventIDs implements input.EventDatabase -func (d *Database) StateEntriesForEventIDs( - ctx context.Context, eventIDs []string, -) ([]types.StateEntry, error) { - return d.statements.bulkSelectStateEventByID(ctx, eventIDs) -} - -// EventTypeNIDs implements state.RoomStateDatabase -func (d *Database) EventTypeNIDs( - ctx context.Context, eventTypes []string, -) (map[string]types.EventTypeNID, error) { - return d.statements.bulkSelectEventTypeNID(ctx, eventTypes) -} - -// EventStateKeyNIDs implements state.RoomStateDatabase -func (d *Database) EventStateKeyNIDs( - ctx context.Context, eventStateKeys []string, -) (map[string]types.EventStateKeyNID, error) { - return d.statements.bulkSelectEventStateKeyNID(ctx, eventStateKeys) -} - -// EventStateKeys implements query.RoomserverQueryAPIDatabase -func (d *Database) EventStateKeys( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, -) (map[types.EventStateKeyNID]string, error) { - return d.statements.bulkSelectEventStateKey(ctx, eventStateKeyNIDs) -} - -// EventNIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) EventNIDs( - ctx context.Context, eventIDs []string, -) (map[string]types.EventNID, error) { - return d.statements.bulkSelectEventNID(ctx, eventIDs) -} - -// Events implements input.EventDatabase -func (d *Database) Events( - ctx context.Context, eventNIDs []types.EventNID, -) ([]types.Event, error) { - eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs) - if err != nil { - return nil, err - } - results := make([]types.Event, len(eventJSONs)) - for i, eventJSON := range eventJSONs { - var roomNID types.RoomNID - var roomVersion gomatrixserverlib.RoomVersion - result := &results[i] - result.EventNID = eventJSON.EventNID - roomNID, err = d.statements.selectRoomNIDForEventNID(ctx, nil, eventJSON.EventNID) - if err != nil { - return nil, err - } - roomVersion, err = d.statements.selectRoomVersionForRoomNID(ctx, nil, roomNID) - if err != nil { - return nil, err - } - result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON( - eventJSON.EventJSON, false, roomVersion, - ) - if err != nil { - return nil, err - } - } - return results, nil -} - -// AddState implements input.EventDatabase -func (d *Database) AddState( - ctx context.Context, - roomNID types.RoomNID, - stateBlockNIDs []types.StateBlockNID, - state []types.StateEntry, -) (types.StateSnapshotNID, error) { - if len(state) > 0 { - stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx) - if err != nil { - return 0, err - } - if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil { - return 0, err - } - stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) - } - - return d.statements.insertState(ctx, roomNID, stateBlockNIDs) -} - -// SetState implements input.EventDatabase -func (d *Database) SetState( - ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, -) error { - return d.statements.updateEventState(ctx, eventNID, stateNID) -} - -// StateAtEventIDs implements input.EventDatabase -func (d *Database) StateAtEventIDs( - ctx context.Context, eventIDs []string, -) ([]types.StateAtEvent, error) { - return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs) -} - -// StateBlockNIDs implements state.RoomStateDatabase -func (d *Database) StateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, -) ([]types.StateBlockNIDList, error) { - return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs) -} - -// StateEntries implements state.RoomStateDatabase -func (d *Database) StateEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { - return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs) -} - -// SnapshotNIDFromEventID implements state.RoomStateDatabase -func (d *Database) SnapshotNIDFromEventID( - ctx context.Context, eventID string, -) (types.StateSnapshotNID, error) { - _, stateNID, err := d.statements.selectEvent(ctx, eventID) - return stateNID, err -} - -// EventIDs implements input.RoomEventDatabase -func (d *Database) EventIDs( - ctx context.Context, eventNIDs []types.EventNID, -) (map[types.EventNID]string, error) { - return d.statements.bulkSelectEventID(ctx, eventNIDs) -} - -// GetLatestEventsForUpdate implements input.EventDatabase -func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomNID types.RoomNID, -) (types.RoomRecentEventsUpdater, error) { - txn, err := d.db.Begin() - if err != nil { - return nil, err - } - eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := - d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID) - if err != nil { - txn.Rollback() // nolint: errcheck - return nil, err - } - stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) - if err != nil { - txn.Rollback() // nolint: errcheck - return nil, err - } - var lastEventIDSent string - if lastEventNIDSent != 0 { - lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent) - if err != nil { - txn.Rollback() // nolint: errcheck - return nil, err - } - } - return &roomRecentEventsUpdater{ - transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, - }, nil -} - -// GetTransactionEventID implements input.EventDatabase -func (d *Database) GetTransactionEventID( - ctx context.Context, transactionID string, - sessionID int64, userID string, -) (string, error) { - eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID) - if err == sql.ErrNoRows { - return "", nil - } - return eventID, err -} - -type roomRecentEventsUpdater struct { - transaction - d *Database - roomNID types.RoomNID - latestEvents []types.StateAtEventAndReference - lastEventIDSent string - currentStateSnapshotNID types.StateSnapshotNID -} - -// RoomVersion implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { - version, _ = u.d.GetRoomVersionForRoomNID(u.ctx, u.roomNID) - return -} - -// LatestEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference { - return u.latestEvents -} - -// LastEventIDSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) LastEventIDSent() string { - return u.lastEventIDSent -} - -// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { - return u.currentStateSnapshotNID -} - -// StorePreviousEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - for _, ref := range previousEventReferences { - if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return err - } - } - return nil -} - -// IsReferenced implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { - err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) - if err == nil { - return true, nil - } - if err == sql.ErrNoRows { - return false, nil - } - return false, err -} - -// SetLatestEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) SetLatestEvents( - roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, - currentStateSnapshotNID types.StateSnapshotNID, -) error { - eventNIDs := make([]types.EventNID, len(latest)) - for i := range latest { - eventNIDs[i] = latest[i].EventNID - } - return u.d.statements.updateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) -} - -// HasEventBeenSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { - return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID) -} - -// MarkEventAsSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID) -} - -func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (types.MembershipUpdater, error) { - return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) -} - -// RoomNID implements query.RoomserverQueryAPIDB -func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) { - roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID) - if err == sql.ErrNoRows { - return 0, nil - } - return roomNID, err -} - -// RoomNIDExcludingStubs implements query.RoomserverQueryAPIDB -func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) { - roomNID, err = d.RoomNID(ctx, roomID) - if err != nil { - return - } - latestEvents, _, err := d.statements.selectLatestEventNIDs(ctx, roomNID) - if err != nil { - return - } - if len(latestEvents) == 0 { - roomNID = 0 - return - } - return -} - -// LatestEventIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) LatestEventIDs( - ctx context.Context, roomNID types.RoomNID, -) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) { - eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID) - if err != nil { - return nil, 0, 0, err - } - references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs) - if err != nil { - return nil, 0, 0, err - } - depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs) - if err != nil { - return nil, 0, 0, err - } - return references, currentStateSnapshotNID, depth, nil -} - -// GetInvitesForUser implements query.RoomserverQueryAPIDatabase -func (d *Database) GetInvitesForUser( - ctx context.Context, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, -) (senderUserIDs []types.EventStateKeyNID, err error) { - return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) -} - -// SetRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { - return d.statements.insertRoomAlias(ctx, alias, roomID, creatorUserID) -} - -// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB -func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { - return d.statements.selectRoomIDFromAlias(ctx, alias) -} - -// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB -func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { - return d.statements.selectAliasesFromRoomID(ctx, roomID) -} - -// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB -func (d *Database) GetCreatorIDForAlias( - ctx context.Context, alias string, -) (string, error) { - return d.statements.selectCreatorIDFromAlias(ctx, alias) -} - -// RemoveRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { - return d.statements.deleteRoomAlias(ctx, alias) -} - -// StateEntriesForTuples implements state.RoomStateDatabase -func (d *Database) StateEntriesForTuples( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - return d.statements.bulkSelectFilteredStateBlockEntries( - ctx, stateBlockNIDs, stateKeyTuples, - ) -} - -// MembershipUpdater implements input.RoomEventDatabase -func (d *Database) MembershipUpdater( - ctx context.Context, roomID, targetUserID string, - targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, -) (types.MembershipUpdater, error) { - txn, err := d.db.Begin() - if err != nil { - return nil, err - } - succeeded := false - defer func() { - if !succeeded { - txn.Rollback() // nolint: errcheck - } - }() - - roomNID, err := d.assignRoomNID(ctx, txn, roomID, roomVersion) - if err != nil { - return nil, err - } - - targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID) - if err != nil { - return nil, err - } - - updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal) - if err != nil { - return nil, err - } - - succeeded = true - return updater, nil -} - -type membershipUpdater struct { - transaction - d *Database - roomNID types.RoomNID - targetUserNID types.EventStateKeyNID - membership membershipState -} - -func (d *Database) membershipUpdaterTxn( - ctx context.Context, - txn *sql.Tx, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, - targetLocal bool, -) (types.MembershipUpdater, error) { - - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { - return nil, err - } - - membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) - if err != nil { - return nil, err - } - - return &membershipUpdater{ - transaction{ctx, txn}, d, roomNID, targetUserNID, membership, - }, nil -} - -// IsInvite implements types.MembershipUpdater -func (u *membershipUpdater) IsInvite() bool { - return u.membership == membershipStateInvite -} - -// IsJoin implements types.MembershipUpdater -func (u *membershipUpdater) IsJoin() bool { - return u.membership == membershipStateJoin -} - -// IsLeave implements types.MembershipUpdater -func (u *membershipUpdater) IsLeave() bool { - return u.membership == membershipStateLeaveOrBan -} - -// SetToInvite implements types.MembershipUpdater -func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) - if err != nil { - return false, err - } - inserted, err := u.d.statements.insertInviteEvent( - u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), - ) - if err != nil { - return false, err - } - if u.membership != membershipStateInvite { - if err = u.d.statements.updateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, - ); err != nil { - return false, err - } - } - return inserted, nil -} - -// SetToJoin implements types.MembershipUpdater -func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { - var inviteEventIDs []string - - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) - if err != nil { - return nil, err - } - - // If this is a join event update, there is no invite to update - if !isUpdate { - inviteEventIDs, err = u.d.statements.updateInviteRetired( - u.ctx, u.txn, u.roomNID, u.targetUserNID, - ) - if err != nil { - return nil, err - } - } - - // Look up the NID of the new join event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return nil, err - } - - if u.membership != membershipStateJoin || isUpdate { - if err = u.d.statements.updateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateJoin, nIDs[eventID], - ); err != nil { - return nil, err - } - } - - return inviteEventIDs, nil -} - -// SetToLeave implements types.MembershipUpdater -func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) - if err != nil { - return nil, err - } - inviteEventIDs, err := u.d.statements.updateInviteRetired( - u.ctx, u.txn, u.roomNID, u.targetUserNID, - ) - if err != nil { - return nil, err - } - - // Look up the NID of the new leave event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return nil, err - } - - if u.membership != membershipStateLeaveOrBan { - if err = u.d.statements.updateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateLeaveOrBan, nIDs[eventID], - ); err != nil { - return nil, err - } - } - return inviteEventIDs, nil -} - -// GetMembership implements query.RoomserverQueryAPIDB -func (d *Database) GetMembership( - ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, -) (membershipEventNID types.EventNID, stillInRoom bool, err error) { - requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) - if err != nil { - return - } - - senderMembershipEventNID, senderMembership, err := - d.statements.selectMembershipFromRoomAndTarget( - ctx, roomNID, requestSenderUserNID, - ) - if err == sql.ErrNoRows { - // The user has never been a member of that room - return 0, false, nil - } else if err != nil { - return - } - - return senderMembershipEventNID, senderMembership == membershipStateJoin, nil -} - -// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB -func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, -) ([]types.EventNID, error) { - if joinOnly { - return d.statements.selectMembershipsFromRoomAndMembership( - ctx, roomNID, membershipStateJoin, localOnly, - ) - } - - return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly) -} - -// EventsFromIDs implements query.RoomserverQueryAPIEventDB -func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - nidMap, err := d.EventNIDs(ctx, eventIDs) - if err != nil { - return nil, err - } - - var nids []types.EventNID - for _, nid := range nidMap { - nids = append(nids, nid) - } - - return d.Events(ctx, nids) -} - -func (d *Database) GetRoomVersionForRoom( - ctx context.Context, roomID string, -) (gomatrixserverlib.RoomVersion, error) { - return d.statements.selectRoomVersionForRoomID( - ctx, nil, roomID, - ) -} - -func (d *Database) GetRoomVersionForRoomNID( - ctx context.Context, roomNID types.RoomNID, -) (gomatrixserverlib.RoomVersion, error) { - return d.statements.selectRoomVersionForRoomNID( - ctx, nil, roomNID, - ) -} - -type transaction struct { - ctx context.Context - txn *sql.Tx -} - -// Commit implements types.Transaction -func (t *transaction) Commit() error { - return t.txn.Commit() -} - -// Rollback implements types.Transaction -func (t *transaction) Rollback() error { - return t.txn.Rollback() -} diff --git a/roomserver/storage/postgres/transactions_table.go b/roomserver/storage/postgres/transactions_table.go index 87c1cacae..5e59ae16d 100644 --- a/roomserver/storage/postgres/transactions_table.go +++ b/roomserver/storage/postgres/transactions_table.go @@ -18,6 +18,9 @@ package postgres import ( "context" "database/sql" + + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" ) const transactionsSchema = ` @@ -51,20 +54,21 @@ type transactionStatements struct { selectTransactionEventIDStmt *sql.Stmt } -func (s *transactionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(transactionsSchema) +func NewPostgresTransactionsTable(db *sql.DB) (tables.Transactions, error) { + s := &transactionStatements{} + _, err := db.Exec(transactionsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertTransactionStmt, insertTransactionSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *transactionStatements) insertTransaction( - ctx context.Context, +func (s *transactionStatements) InsertTransaction( + ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, userID string, @@ -76,7 +80,7 @@ func (s *transactionStatements) insertTransaction( return } -func (s *transactionStatements) selectTransactionEventID( +func (s *transactionStatements) SelectTransactionEventID( ctx context.Context, transactionID string, sessionID int64, diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go new file mode 100644 index 000000000..5ddf6d84d --- /dev/null +++ b/roomserver/storage/shared/membership_updater.go @@ -0,0 +1,183 @@ +package shared + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type membershipUpdater struct { + transaction + d *Database + roomNID types.RoomNID + targetUserNID types.EventStateKeyNID + membership tables.MembershipState +} + +func NewMembershipUpdater( + ctx context.Context, d *Database, roomID, targetUserID string, + targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, + useTxns bool, +) (types.MembershipUpdater, error) { + txn, err := d.DB.Begin() + if err != nil { + return nil, err + } + succeeded := false + defer func() { + if !succeeded { + txn.Rollback() // nolint: errcheck + } + }() + + roomNID, err := d.assignRoomNID(ctx, txn, roomID, roomVersion) + if err != nil { + return nil, err + } + + targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID) + if err != nil { + return nil, err + } + + updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal) + if err != nil { + return nil, err + } + + succeeded = true + if !useTxns { + txn.Commit() // nolint: errcheck + updater.transaction.txn = nil + } + return updater, nil +} + +func (d *Database) membershipUpdaterTxn( + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, + targetLocal bool, +) (*membershipUpdater, error) { + + if err := d.MembershipTable.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { + return nil, err + } + + membership, err := d.MembershipTable.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) + if err != nil { + return nil, err + } + + return &membershipUpdater{ + transaction{ctx, txn}, d, roomNID, targetUserNID, membership, + }, nil +} + +// IsInvite implements types.MembershipUpdater +func (u *membershipUpdater) IsInvite() bool { + return u.membership == tables.MembershipStateInvite +} + +// IsJoin implements types.MembershipUpdater +func (u *membershipUpdater) IsJoin() bool { + return u.membership == tables.MembershipStateJoin +} + +// IsLeave implements types.MembershipUpdater +func (u *membershipUpdater) IsLeave() bool { + return u.membership == tables.MembershipStateLeaveOrBan +} + +// SetToInvite implements types.MembershipUpdater +func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) + if err != nil { + return false, err + } + inserted, err := u.d.InvitesTable.InsertInviteEvent( + u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), + ) + if err != nil { + return false, err + } + if u.membership != tables.MembershipStateInvite { + if err = u.d.MembershipTable.UpdateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, + ); err != nil { + return false, err + } + } + return inserted, nil +} + +// SetToJoin implements types.MembershipUpdater +func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { + var inviteEventIDs []string + + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) + if err != nil { + return nil, err + } + + // If this is a join event update, there is no invite to update + if !isUpdate { + inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired( + u.ctx, u.txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return nil, err + } + } + + // Look up the NID of the new join event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return nil, err + } + + if u.membership != tables.MembershipStateJoin || isUpdate { + if err = u.d.MembershipTable.UpdateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + tables.MembershipStateJoin, nIDs[eventID], + ); err != nil { + return nil, err + } + } + + return inviteEventIDs, nil +} + +// SetToLeave implements types.MembershipUpdater +func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) + if err != nil { + return nil, err + } + inviteEventIDs, err := u.d.InvitesTable.UpdateInviteRetired( + u.ctx, u.txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return nil, err + } + + // Look up the NID of the new leave event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return nil, err + } + + if u.membership != tables.MembershipStateLeaveOrBan { + if err = u.d.MembershipTable.UpdateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + tables.MembershipStateLeaveOrBan, nIDs[eventID], + ); err != nil { + return nil, err + } + } + return inviteEventIDs, nil +} diff --git a/roomserver/storage/shared/prepare.go b/roomserver/storage/shared/prepare.go new file mode 100644 index 000000000..65ceec1cc --- /dev/null +++ b/roomserver/storage/shared/prepare.go @@ -0,0 +1,60 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "database/sql" +) + +// StatementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. +type StatementList []struct { + Statement **sql.Stmt + SQL string +} + +// Prepare the SQL for each statement in the list and assign the result to the prepared statement. +func (s StatementList) Prepare(db *sql.DB) (err error) { + for _, statement := range s { + if *statement.Statement, err = db.Prepare(statement.SQL); err != nil { + return + } + } + return +} + +type transaction struct { + ctx context.Context + txn *sql.Tx +} + +// Commit implements types.Transaction +func (t *transaction) Commit() error { + if t.txn == nil { + // The Updater structs can operate in useTxns=false mode. The code will still call this though. + return nil + } + return t.txn.Commit() +} + +// Rollback implements types.Transaction +func (t *transaction) Rollback() error { + if t.txn == nil { + // The Updater structs can operate in useTxns=false mode. The code will still call this though. + return nil + } + return t.txn.Rollback() +} diff --git a/roomserver/storage/shared/room_recent_events_updater.go b/roomserver/storage/shared/room_recent_events_updater.go new file mode 100644 index 000000000..8131f712d --- /dev/null +++ b/roomserver/storage/shared/room_recent_events_updater.go @@ -0,0 +1,120 @@ +package shared + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type roomRecentEventsUpdater struct { + transaction + d *Database + roomNID types.RoomNID + latestEvents []types.StateAtEventAndReference + lastEventIDSent string + currentStateSnapshotNID types.StateSnapshotNID +} + +func NewRoomRecentEventsUpdater(d *Database, ctx context.Context, roomNID types.RoomNID, useTxns bool) (types.RoomRecentEventsUpdater, error) { + txn, err := d.DB.Begin() + if err != nil { + return nil, err + } + eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := + d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomNID) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + var lastEventIDSent string + if lastEventNIDSent != 0 { + lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + } + if !useTxns { + txn.Commit() // nolint: errcheck + txn = nil + } + return &roomRecentEventsUpdater{ + transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + }, nil +} + +// RoomVersion implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { + version, _ = u.d.GetRoomVersionForRoomNID(u.ctx, u.roomNID) + return +} + +// LatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference { + return u.latestEvents +} + +// LastEventIDSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LastEventIDSent() string { + return u.lastEventIDSent +} + +// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { + return u.currentStateSnapshotNID +} + +// StorePreviousEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { + for _, ref := range previousEventReferences { + if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return err + } + } + return nil +} + +// IsReferenced implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { + err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) + if err == nil { + return true, nil + } + if err == sql.ErrNoRows { + return false, nil + } + return false, err +} + +// SetLatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) SetLatestEvents( + roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, + currentStateSnapshotNID types.StateSnapshotNID, +) error { + eventNIDs := make([]types.EventNID, len(latest)) + for i := range latest { + eventNIDs[i] = latest[i].EventNID + } + return u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) +} + +// HasEventBeenSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { + return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID) +} + +// MarkEventAsSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { + return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, u.txn, eventNID) +} + +func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (types.MembershipUpdater, error) { + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go new file mode 100644 index 000000000..bb5b51562 --- /dev/null +++ b/roomserver/storage/shared/storage.go @@ -0,0 +1,493 @@ +package shared + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database struct { + DB *sql.DB + EventsTable tables.Events + EventJSONTable tables.EventJSON + EventTypesTable tables.EventTypes + EventStateKeysTable tables.EventStateKeys + RoomsTable tables.Rooms + TransactionsTable tables.Transactions + StateSnapshotTable tables.StateSnapshot + StateBlockTable tables.StateBlock + RoomAliasesTable tables.RoomAliases + PrevEventsTable tables.PreviousEvents + InvitesTable tables.Invites + MembershipTable tables.Membership +} + +func (d *Database) EventTypeNIDs( + ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + return d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes) +} + +func (d *Database) EventStateKeys( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, +) (map[types.EventStateKeyNID]string, error) { + return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs) +} + +func (d *Database) EventStateKeyNIDs( + ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + return d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys) +} + +func (d *Database) StateEntriesForEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateEntry, error) { + return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs) +} + +func (d *Database) StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return d.StateBlockTable.BulkSelectFilteredStateBlockEntries( + ctx, stateBlockNIDs, stateKeyTuples, + ) +} + +func (d *Database) AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, +) (stateNID types.StateSnapshotNID, err error) { + err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { + if len(state) > 0 { + var stateBlockNID types.StateBlockNID + stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state) + if err != nil { + return err + } + stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) + } + stateNID, err = d.StateSnapshotTable.InsertState(ctx, txn, roomNID, stateBlockNIDs) + return err + }) + if err != nil { + return 0, err + } + return +} + +func (d *Database) EventNIDs( + ctx context.Context, eventIDs []string, +) (map[string]types.EventNID, error) { + return d.EventsTable.BulkSelectEventNID(ctx, eventIDs) +} + +func (d *Database) SetState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID) +} + +func (d *Database) StateAtEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateAtEvent, error) { + return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs) +} + +func (d *Database) SnapshotNIDFromEventID( + ctx context.Context, eventID string, +) (types.StateSnapshotNID, error) { + _, stateNID, err := d.EventsTable.SelectEvent(ctx, nil, eventID) + return stateNID, err +} + +func (d *Database) EventIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (map[types.EventNID]string, error) { + return d.EventsTable.BulkSelectEventID(ctx, eventNIDs) +} + +func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + nidMap, err := d.EventNIDs(ctx, eventIDs) + if err != nil { + return nil, err + } + + var nids []types.EventNID + for _, nid := range nidMap { + nids = append(nids, nid) + } + + return d.Events(ctx, nids) +} + +func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) { + roomNID, err := d.RoomsTable.SelectRoomNID(ctx, nil, roomID) + if err == sql.ErrNoRows { + return 0, nil + } + return roomNID, err +} + +func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) { + roomNID, err = d.RoomNID(ctx, roomID) + if err != nil { + return + } + latestEvents, _, err := d.RoomsTable.SelectLatestEventNIDs(ctx, nil, roomNID) + if err != nil { + return + } + if len(latestEvents) == 0 { + roomNID = 0 + return + } + return +} + +func (d *Database) LatestEventIDs( + ctx context.Context, roomNID types.RoomNID, +) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { + err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { + var eventNIDs []types.EventNID + eventNIDs, currentStateSnapshotNID, err = d.RoomsTable.SelectLatestEventNIDs(ctx, txn, roomNID) + if err != nil { + return err + } + references, err = d.EventsTable.BulkSelectEventReference(ctx, txn, eventNIDs) + if err != nil { + return err + } + depth, err = d.EventsTable.SelectMaxEventDepth(ctx, txn, eventNIDs) + if err != nil { + return err + } + return nil + }) + return +} + +func (d *Database) StateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs) +} + +func (d *Database) StateEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) +} + +func (d *Database) GetRoomVersionForRoom( + ctx context.Context, roomID string, +) (gomatrixserverlib.RoomVersion, error) { + return d.RoomsTable.SelectRoomVersionForRoomID( + ctx, nil, roomID, + ) +} + +func (d *Database) GetRoomVersionForRoomNID( + ctx context.Context, roomNID types.RoomNID, +) (gomatrixserverlib.RoomVersion, error) { + return d.RoomsTable.SelectRoomVersionForRoomNID( + ctx, roomNID, + ) +} + +func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { + return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID) +} + +func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { + return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias) +} + +func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { + return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID) +} + +func (d *Database) GetCreatorIDForAlias( + ctx context.Context, alias string, +) (string, error) { + return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias) +} + +func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { + return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias) +} + +func (d *Database) GetMembership( + ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, +) (membershipEventNID types.EventNID, stillInRoom bool, err error) { + requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) + if err != nil { + return + } + + senderMembershipEventNID, senderMembership, err := + d.MembershipTable.SelectMembershipFromRoomAndTarget( + ctx, roomNID, requestSenderUserNID, + ) + if err == sql.ErrNoRows { + // The user has never been a member of that room + return 0, false, nil + } else if err != nil { + return + } + + return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, nil +} + +func (d *Database) GetMembershipEventNIDsForRoom( + ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, +) ([]types.EventNID, error) { + if joinOnly { + return d.MembershipTable.SelectMembershipsFromRoomAndMembership( + ctx, roomNID, tables.MembershipStateJoin, localOnly, + ) + } + + return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly) +} + +func (d *Database) GetInvitesForUser( + ctx context.Context, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, +) (senderUserIDs []types.EventStateKeyNID, err error) { + return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) +} + +func (d *Database) Events( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.Event, error) { + eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) + if err != nil { + return nil, err + } + results := make([]types.Event, len(eventJSONs)) + for i, eventJSON := range eventJSONs { + var roomNID types.RoomNID + var roomVersion gomatrixserverlib.RoomVersion + result := &results[i] + result.EventNID = eventJSON.EventNID + roomNID, err = d.EventsTable.SelectRoomNIDForEventNID(ctx, eventJSON.EventNID) + if err != nil { + return nil, err + } + roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID) + if err != nil { + return nil, err + } + result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON( + eventJSON.EventJSON, false, roomVersion, + ) + if err != nil { + return nil, err + } + } + return results, nil +} + +func (d *Database) GetTransactionEventID( + ctx context.Context, transactionID string, + sessionID int64, userID string, +) (string, error) { + eventID, err := d.TransactionsTable.SelectTransactionEventID(ctx, transactionID, sessionID, userID) + if err == sql.ErrNoRows { + return "", nil + } + return eventID, err +} + +func (d *Database) MembershipUpdater( + ctx context.Context, roomID, targetUserID string, + targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, +) (types.MembershipUpdater, error) { + return NewMembershipUpdater(ctx, d, roomID, targetUserID, targetLocal, roomVersion, true) +} + +func (d *Database) GetLatestEventsForUpdate( + ctx context.Context, roomNID types.RoomNID, +) (types.RoomRecentEventsUpdater, error) { + return NewRoomRecentEventsUpdater(d, ctx, roomNID, true) +} + +func (d *Database) StoreEvent( + ctx context.Context, event gomatrixserverlib.Event, + txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, +) (types.RoomNID, types.StateAtEvent, error) { + var ( + roomNID types.RoomNID + eventTypeNID types.EventTypeNID + eventStateKeyNID types.EventStateKeyNID + eventNID types.EventNID + stateNID types.StateSnapshotNID + err error + ) + + err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { + if txnAndSessionID != nil { + if err = d.TransactionsTable.InsertTransaction( + ctx, txn, txnAndSessionID.TransactionID, + txnAndSessionID.SessionID, event.Sender(), event.EventID(), + ); err != nil { + return err + } + } + + // TODO: Here we should aim to have two different code paths for new rooms + // vs existing ones. + + // Get the default room version. If the client doesn't supply a room_version + // then we will use our configured default to create the room. + // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom + // Note that the below logic depends on the m.room.create event being the + // first event that is persisted to the database when creating or joining a + // room. + var roomVersion gomatrixserverlib.RoomVersion + if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { + return err + } + + if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil { + return err + } + + if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil { + return err + } + + eventStateKey := event.StateKey() + // Assigned a numeric ID for the state_key if there is one present. + // Otherwise set the numeric ID for the state_key to 0. + if eventStateKey != nil { + if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { + return err + } + } + + if eventNID, stateNID, err = d.EventsTable.InsertEvent( + ctx, + txn, + roomNID, + eventTypeNID, + eventStateKeyNID, + event.EventID(), + event.EventReference().EventSHA256, + authEventNIDs, + event.Depth(), + ); err != nil { + if err == sql.ErrNoRows { + // We've already inserted the event so select the numeric event ID + eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID()) + } + if err != nil { + return err + } + } + + if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { + return err + } + + return nil + }) + if err != nil { + return 0, types.StateAtEvent{}, err + } + + return roomNID, types.StateAtEvent{ + BeforeStateSnapshotNID: stateNID, + StateEntry: types.StateEntry{ + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: eventTypeNID, + EventStateKeyNID: eventStateKeyNID, + }, + EventNID: eventNID, + }, + }, nil +} + +func (d *Database) assignRoomNID( + ctx context.Context, txn *sql.Tx, + roomID string, roomVersion gomatrixserverlib.RoomVersion, +) (types.RoomNID, error) { + // Check if we already have a numeric ID in the database. + roomNID, err := d.RoomsTable.SelectRoomNID(ctx, txn, roomID) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + roomNID, err = d.RoomsTable.InsertRoomNID(ctx, txn, roomID, roomVersion) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + roomNID, err = d.RoomsTable.SelectRoomNID(ctx, txn, roomID) + } + } + return roomNID, err +} + +func (d *Database) assignEventTypeNID( + ctx context.Context, txn *sql.Tx, eventType string, +) (eventTypeNID types.EventTypeNID, err error) { + // Check if we already have a numeric ID in the database. + eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventTypeNID, err = d.EventTypesTable.InsertEventTypeNID(ctx, txn, eventType) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) + } + } + return +} + +func (d *Database) assignStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { + // Check if we already have a numeric ID in the database. + eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventStateKeyNID, err = d.EventStateKeysTable.InsertEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) + } + } + return eventStateKeyNID, err +} + +func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) ( + gomatrixserverlib.RoomVersion, error, +) { + var err error + var roomVersion gomatrixserverlib.RoomVersion + // Look for m.room.create events. + if event.Type() != gomatrixserverlib.MRoomCreate { + return gomatrixserverlib.RoomVersion(""), nil + } + roomVersion = gomatrixserverlib.RoomVersionV1 + var createContent gomatrixserverlib.CreateContent + // The m.room.create event contains an optional "room_version" key in + // the event content, so we need to unmarshal that first. + if err = json.Unmarshal(event.Content(), &createContent); err != nil { + return gomatrixserverlib.RoomVersion(""), err + } + // A room version was specified in the event content? + if createContent.RoomVersion != nil { + roomVersion = gomatrixserverlib.RoomVersion(*createContent.RoomVersion) + } + return roomVersion, err +} diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index fbf35e711..7ff4c1d4d 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -21,6 +21,8 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -51,40 +53,36 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { +func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { + s := &eventJSONStatements{} s.db = db - _, err = db.Exec(eventJSONSchema) + _, err := db.Exec(eventJSONSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventJSONStatements) insertEventJSON( +func (s *eventJSONStatements) InsertEventJSON( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { _, err := internal.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) return err } -type eventJSONPair struct { - EventNID types.EventNID - EventJSON []byte -} - -func (s *eventJSONStatements) bulkSelectEventJSON( - ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, -) ([]eventJSONPair, error) { +func (s *eventJSONStatements) BulkSelectEventJSON( + ctx context.Context, eventNIDs []types.EventNID, +) ([]tables.EventJSONPair, error) { iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", internal.QueryVariadic(len(iEventNIDs)), 1) - rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...) + rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...) if err != nil { return nil, err } @@ -94,7 +92,7 @@ func (s *eventJSONStatements) bulkSelectEventJSON( // because of the unique constraint on event NIDs. // So we can allocate an array of the correct size now. // We might get fewer results than NIDs so we adjust the length of the slice before returning it. - results := make([]eventJSONPair, len(eventNIDs)) + results := make([]tables.EventJSONPair, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { result := &results[i] diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index f49ebf554..2d2c04d26 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -21,6 +21,8 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -67,44 +69,45 @@ type eventStateKeyStatements struct { bulkSelectEventStateKeyStmt *sql.Stmt } -func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { +func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { + s := &eventStateKeyStatements{} s.db = db - _, err = db.Exec(eventStateKeysSchema) + _, err := db.Exec(eventStateKeysSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventStateKeyStatements) insertEventStateKeyNID( +func (s *eventStateKeyStatements) InsertEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 var err error var res sql.Result - insertStmt := txn.Stmt(s.insertEventStateKeyNIDStmt) + insertStmt := internal.TxStmt(txn, s.insertEventStateKeyNIDStmt) if res, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil { eventStateKeyNID, err = res.LastInsertId() } return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *eventStateKeyStatements) selectEventStateKeyNID( +func (s *eventStateKeyStatements) SelectEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - stmt := txn.Stmt(s.selectEventStateKeyNIDStmt) + stmt := internal.TxStmt(txn, s.selectEventStateKeyNIDStmt) err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( - ctx context.Context, txn *sql.Tx, eventStateKeys []string, +func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( + ctx context.Context, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { iEventStateKeys := make([]interface{}, len(eventStateKeys)) for k, v := range eventStateKeys { @@ -112,7 +115,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( } selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", internal.QueryVariadic(len(eventStateKeys)), 1) - rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeys...) + rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) if err != nil { return nil, err } @@ -129,8 +132,8 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( return result, nil } -func (s *eventStateKeyStatements) bulkSelectEventStateKey( - ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, +func (s *eventStateKeyStatements) BulkSelectEventStateKey( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs)) for k, v := range eventStateKeyNIDs { @@ -138,7 +141,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey( } selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", internal.QueryVariadic(len(eventStateKeyNIDs)), 1) - rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) + rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 13abcd4df..6d229bbd7 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -21,6 +21,8 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -81,22 +83,23 @@ type eventTypeStatements struct { bulkSelectEventTypeNIDStmt *sql.Stmt } -func (s *eventTypeStatements) prepare(db *sql.DB) (err error) { +func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) { + s := &eventTypeStatements{} s.db = db - _, err = db.Exec(eventTypesSchema) + _, err := db.Exec(eventTypesSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, {&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL}, {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventTypeStatements) insertEventTypeNID( +func (s *eventTypeStatements) InsertEventTypeNID( ctx context.Context, tx *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 @@ -109,7 +112,7 @@ func (s *eventTypeStatements) insertEventTypeNID( return types.EventTypeNID(eventTypeNID), err } -func (s *eventTypeStatements) selectEventTypeNID( +func (s *eventTypeStatements) SelectEventTypeNID( ctx context.Context, tx *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 @@ -118,8 +121,8 @@ func (s *eventTypeStatements) selectEventTypeNID( return types.EventTypeNID(eventTypeNID), err } -func (s *eventTypeStatements) bulkSelectEventTypeNID( - ctx context.Context, tx *sql.Tx, eventTypes []string, +func (s *eventTypeStatements) BulkSelectEventTypeNID( + ctx context.Context, eventTypes []string, ) (map[string]types.EventTypeNID, error) { /////////////// iEventTypes := make([]interface{}, len(eventTypes)) @@ -133,8 +136,7 @@ func (s *eventTypeStatements) bulkSelectEventTypeNID( } /////////////// - selectStmt := internal.TxStmt(tx, selectPrep) - rows, err := selectStmt.QueryContext(ctx, iEventTypes...) + rows, err := selectPrep.QueryContext(ctx, iEventTypes...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 551134950..491399ce7 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -23,6 +23,8 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -111,14 +113,15 @@ type eventStatements struct { selectRoomNIDForEventNIDStmt *sql.Stmt } -func (s *eventStatements) prepare(db *sql.DB) (err error) { +func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { + s := &eventStatements{} s.db = db - _, err = db.Exec(eventsSchema) + _, err := db.Exec(eventsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, @@ -132,10 +135,10 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *eventStatements) insertEvent( +func (s *eventStatements) InsertEvent( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, @@ -145,7 +148,7 @@ func (s *eventStatements) insertEvent( referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, -) (types.EventNID, error) { +) (types.EventNID, types.StateSnapshotNID, error) { // attempt to insert: the last_row_id is the event NID insertStmt := internal.TxStmt(txn, s.insertEventStmt) result, err := insertStmt.ExecContext( @@ -153,17 +156,17 @@ func (s *eventStatements) insertEvent( eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, ) if err != nil { - return 0, err + return 0, 0, err } modified, err := result.RowsAffected() if modified == 0 && err == nil { - return 0, sql.ErrNoRows + return 0, 0, sql.ErrNoRows } eventNID, err := result.LastInsertId() - return types.EventNID(eventNID), err + return types.EventNID(eventNID), 0, err } -func (s *eventStatements) selectEvent( +func (s *eventStatements) SelectEvent( ctx context.Context, txn *sql.Tx, eventID string, ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 @@ -175,8 +178,8 @@ func (s *eventStatements) selectEvent( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError -func (s *eventStatements) bulkSelectStateEventByID( - ctx context.Context, txn *sql.Tx, eventIDs []string, +func (s *eventStatements) BulkSelectStateEventByID( + ctx context.Context, eventIDs []string, ) ([]types.StateEntry, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) @@ -184,13 +187,12 @@ func (s *eventStatements) bulkSelectStateEventByID( iEventIDs[k] = v } selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", internal.QueryVariadic(len(iEventIDs)), 1) - selectPrep, err := txn.Prepare(selectOrig) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } /////////////// - selectStmt := internal.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { return nil, err @@ -228,8 +230,8 @@ func (s *eventStatements) bulkSelectStateEventByID( // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. -func (s *eventStatements) bulkSelectStateAtEventByID( - ctx context.Context, txn *sql.Tx, eventIDs []string, +func (s *eventStatements) BulkSelectStateAtEventByID( + ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) @@ -237,13 +239,11 @@ func (s *eventStatements) bulkSelectStateAtEventByID( iEventIDs[k] = v } selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", internal.QueryVariadic(len(iEventIDs)), 1) - selectPrep, err := txn.Prepare(selectOrig) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } /////////////// - - selectStmt := internal.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { return nil, err @@ -275,15 +275,14 @@ func (s *eventStatements) bulkSelectStateAtEventByID( return results, err } -func (s *eventStatements) updateEventState( - ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID, +func (s *eventStatements) UpdateEventState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - updateStmt := internal.TxStmt(txn, s.updateEventStateStmt) - _, err := updateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) + _, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) return err } -func (s *eventStatements) selectEventSentToOutput( +func (s *eventStatements) SelectEventSentToOutput( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (sentToOutput bool, err error) { selectStmt := internal.TxStmt(txn, s.selectEventSentToOutputStmt) @@ -294,14 +293,14 @@ func (s *eventStatements) selectEventSentToOutput( return } -func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { +func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { updateStmt := internal.TxStmt(txn, s.updateEventSentToOutputStmt) _, err := updateStmt.ExecContext(ctx, int64(eventNID)) //_, err := s.updateEventSentToOutputStmt.ExecContext(ctx, int64(eventNID)) return err } -func (s *eventStatements) selectEventID( +func (s *eventStatements) SelectEventID( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (eventID string, err error) { selectStmt := internal.TxStmt(txn, s.selectEventIDStmt) @@ -309,7 +308,7 @@ func (s *eventStatements) selectEventID( return } -func (s *eventStatements) bulkSelectStateAtEventAndReference( +func (s *eventStatements) BulkSelectStateAtEventAndReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]types.StateAtEventAndReference, error) { /////////////// @@ -355,7 +354,7 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference( return results, nil } -func (s *eventStatements) bulkSelectEventReference( +func (s *eventStatements) BulkSelectEventReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]gomatrixserverlib.EventReference, error) { /////////////// @@ -391,20 +390,19 @@ func (s *eventStatements) bulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { +func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { /////////////// iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", internal.QueryVariadic(len(iEventNIDs)), 1) - selectPrep, err := txn.Prepare(selectOrig) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } /////////////// - selectStmt := internal.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) if err != nil { return nil, err @@ -428,20 +426,18 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, ev // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { iEventIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", internal.QueryVariadic(len(iEventIDs)), 1) - selectPrep, err := txn.Prepare(selectOrig) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } /////////////// - - selectStmt := internal.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { return nil, err @@ -459,7 +455,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e return results, nil } -func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { +func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { var result int64 iEventIDs := make([]interface{}, len(eventNIDs)) for i, v := range eventNIDs { @@ -473,11 +469,10 @@ func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, return result, nil } -func (s *eventStatements) selectRoomNIDForEventNID( - ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +func (s *eventStatements) SelectRoomNIDForEventNID( + ctx context.Context, eventNID types.EventNID, ) (roomNID types.RoomNID, err error) { - selectStmt := internal.TxStmt(txn, s.selectRoomNIDForEventNIDStmt) - err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) + err = s.selectRoomNIDForEventNIDStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) return } diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index a42d18a73..7dcc2dc0b 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -20,6 +20,8 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -66,28 +68,28 @@ type inviteStatements struct { selectInvitesAboutToRetireStmt *sql.Stmt } -func (s *inviteStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(inviteSchema) +func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) { + s := &inviteStatements{} + _, err := db.Exec(inviteSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, {&s.selectInvitesAboutToRetireStmt, selectInvitesAboutToRetireSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *inviteStatements) insertInviteEvent( +func (s *inviteStatements) InsertInviteEvent( ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { stmt := internal.TxStmt(txn, s.insertInviteEventStmt) - defer stmt.Close() // nolint: errcheck result, err := stmt.ExecContext( ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, ) @@ -101,12 +103,12 @@ func (s *inviteStatements) insertInviteEvent( return count != 0, nil } -func (s *inviteStatements) updateInviteRetired( +func (s *inviteStatements) UpdateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { // gather all the event IDs we will retire - stmt := txn.Stmt(s.selectInvitesAboutToRetireStmt) + stmt := internal.TxStmt(txn, s.selectInvitesAboutToRetireStmt) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) if err != nil { return nil, err @@ -121,13 +123,13 @@ func (s *inviteStatements) updateInviteRetired( } // now retire the invites - stmt = txn.Stmt(s.updateInviteRetiredStmt) + stmt = internal.TxStmt(txn, s.updateInviteRetiredStmt) _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) return } // selectInviteActiveForUserInRoom returns a list of sender state key NIDs -func (s *inviteStatements) selectInviteActiveForUserInRoom( +func (s *inviteStatements) SelectInviteActiveForUserInRoom( ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, error) { diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 34108af43..5f3076c5a 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -20,17 +20,11 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) -type membershipState int64 - -const ( - membershipStateLeaveOrBan membershipState = 1 - membershipStateInvite membershipState = 2 - membershipStateJoin membershipState = 3 -) - const membershipSchema = ` CREATE TABLE IF NOT EXISTS roomserver_membership ( room_nid INTEGER NOT NULL, @@ -91,13 +85,14 @@ type membershipStatements struct { updateMembershipStmt *sql.Stmt } -func (s *membershipStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(membershipSchema) +func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { + s := &membershipStatements{} + _, err := db.Exec(membershipSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertMembershipStmt, insertMembershipSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, @@ -106,10 +101,10 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *membershipStatements) insertMembership( +func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, @@ -119,10 +114,10 @@ func (s *membershipStatements) insertMembership( return err } -func (s *membershipStatements) selectMembershipForUpdate( +func (s *membershipStatements) SelectMembershipForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (membership membershipState, err error) { +) (membership tables.MembershipState, err error) { stmt := internal.TxStmt(txn, s.selectMembershipForUpdateStmt) err = stmt.QueryRowContext( ctx, roomNID, targetUserNID, @@ -130,26 +125,25 @@ func (s *membershipStatements) selectMembershipForUpdate( return } -func (s *membershipStatements) selectMembershipFromRoomAndTarget( - ctx context.Context, txn *sql.Tx, +func (s *membershipStatements) SelectMembershipFromRoomAndTarget( + ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (eventNID types.EventNID, membership membershipState, err error) { - selectStmt := internal.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt) - err = selectStmt.QueryRowContext( +) (eventNID types.EventNID, membership tables.MembershipState, err error) { + err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID) return } -func (s *membershipStatements) selectMembershipsFromRoom( - ctx context.Context, txn *sql.Tx, +func (s *membershipStatements) SelectMembershipsFromRoom( + ctx context.Context, roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var selectStmt *sql.Stmt if localOnly { - selectStmt = internal.TxStmt(txn, s.selectLocalMembershipsFromRoomStmt) + selectStmt = s.selectLocalMembershipsFromRoomStmt } else { - selectStmt = internal.TxStmt(txn, s.selectMembershipsFromRoomStmt) + selectStmt = s.selectMembershipsFromRoomStmt } rows, err := selectStmt.QueryContext(ctx, roomNID) if err != nil { @@ -167,15 +161,15 @@ func (s *membershipStatements) selectMembershipsFromRoom( return } -func (s *membershipStatements) selectMembershipsFromRoomAndMembership( - ctx context.Context, txn *sql.Tx, - roomNID types.RoomNID, membership membershipState, localOnly bool, +func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( + ctx context.Context, + roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var stmt *sql.Stmt if localOnly { - stmt = internal.TxStmt(txn, s.selectLocalMembershipsFromRoomAndMembershipStmt) + stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt } else { - stmt = internal.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt) + stmt = s.selectMembershipsFromRoomAndMembershipStmt } rows, err := stmt.QueryContext(ctx, roomNID, membership) if err != nil { @@ -193,10 +187,10 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership( return } -func (s *membershipStatements) updateMembership( +func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - senderUserNID types.EventStateKeyNID, membership membershipState, + senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, ) error { stmt := internal.TxStmt(txn, s.updateMembershipStmt) diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index f344bda95..2b187bba3 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -20,6 +20,8 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -55,19 +57,20 @@ type previousEventStatements struct { selectPreviousEventExistsStmt *sql.Stmt } -func (s *previousEventStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(previousEventSchema) +func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { + s := &previousEventStatements{} + _, err := db.Exec(previousEventSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *previousEventStatements) insertPreviousEvent( +func (s *previousEventStatements) InsertPreviousEvent( ctx context.Context, txn *sql.Tx, previousEventID string, @@ -83,7 +86,7 @@ func (s *previousEventStatements) insertPreviousEvent( // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. -func (s *previousEventStatements) selectPreviousEventExists( +func (s *previousEventStatements) SelectPreviousEventExists( ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, ) error { var ok int64 diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index 592ef9782..da5f9161a 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -20,6 +20,8 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" ) const roomAliasesSchema = ` @@ -60,45 +62,43 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(roomAliasesSchema) +func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { + s := &roomAliasesStatements{} + _, err := db.Exec(roomAliasesSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL}, {&s.deleteRoomAliasStmt, deleteRoomAliasSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *roomAliasesStatements) insertRoomAlias( - ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string, +func (s *roomAliasesStatements) InsertRoomAlias( + ctx context.Context, alias string, roomID string, creatorUserID string, ) (err error) { - insertStmt := internal.TxStmt(txn, s.insertRoomAliasStmt) - _, err = insertStmt.ExecContext(ctx, alias, roomID, creatorUserID) + _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) return } -func (s *roomAliasesStatements) selectRoomIDFromAlias( - ctx context.Context, txn *sql.Tx, alias string, +func (s *roomAliasesStatements) SelectRoomIDFromAlias( + ctx context.Context, alias string, ) (roomID string, err error) { - selectStmt := internal.TxStmt(txn, s.selectRoomIDFromAliasStmt) - err = selectStmt.QueryRowContext(ctx, alias).Scan(&roomID) + err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) if err == sql.ErrNoRows { return "", nil } return } -func (s *roomAliasesStatements) selectAliasesFromRoomID( - ctx context.Context, txn *sql.Tx, roomID string, +func (s *roomAliasesStatements) SelectAliasesFromRoomID( + ctx context.Context, roomID string, ) (aliases []string, err error) { aliases = []string{} - selectStmt := internal.TxStmt(txn, s.selectAliasesFromRoomIDStmt) - rows, err := selectStmt.QueryContext(ctx, roomID) + rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) if err != nil { return } @@ -117,21 +117,19 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID( return } -func (s *roomAliasesStatements) selectCreatorIDFromAlias( - ctx context.Context, txn *sql.Tx, alias string, +func (s *roomAliasesStatements) SelectCreatorIDFromAlias( + ctx context.Context, alias string, ) (creatorID string, err error) { - selectStmt := internal.TxStmt(txn, s.selectCreatorIDFromAliasStmt) - err = selectStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) if err == sql.ErrNoRows { return "", nil } return } -func (s *roomAliasesStatements) deleteRoomAlias( - ctx context.Context, txn *sql.Tx, alias string, +func (s *roomAliasesStatements) DeleteRoomAlias( + ctx context.Context, alias string, ) (err error) { - deleteStmt := internal.TxStmt(txn, s.deleteRoomAliasStmt) - _, err = deleteStmt.ExecContext(ctx, alias) + _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias) return } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index ea949d1e9..d22fceb3b 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -22,6 +22,8 @@ import ( "errors" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -71,12 +73,13 @@ type roomStatements struct { selectRoomVersionForRoomNIDStmt *sql.Stmt } -func (s *roomStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(roomsSchema) +func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { + s := &roomStatements{} + _, err := db.Exec(roomsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, @@ -84,23 +87,23 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *roomStatements) insertRoomNID( +func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (types.RoomNID, error) { var err error insertStmt := internal.TxStmt(txn, s.insertRoomNIDStmt) if _, err = insertStmt.ExecContext(ctx, roomID, roomVersion); err == nil { - return s.selectRoomNID(ctx, txn, roomID) + return s.SelectRoomNID(ctx, txn, roomID) } else { return types.RoomNID(0), err } } -func (s *roomStatements) selectRoomNID( +func (s *roomStatements) SelectRoomNID( ctx context.Context, txn *sql.Tx, roomID string, ) (types.RoomNID, error) { var roomNID int64 @@ -109,7 +112,7 @@ func (s *roomStatements) selectRoomNID( return types.RoomNID(roomNID), err } -func (s *roomStatements) selectLatestEventNIDs( +func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { var eventNIDs []types.EventNID @@ -126,7 +129,7 @@ func (s *roomStatements) selectLatestEventNIDs( return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil } -func (s *roomStatements) selectLatestEventsNIDsForUpdate( +func (s *roomStatements) SelectLatestEventsNIDsForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) { var eventNIDs []types.EventNID @@ -144,7 +147,7 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate( return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil } -func (s *roomStatements) updateLatestEventNIDs( +func (s *roomStatements) UpdateLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, @@ -163,7 +166,7 @@ func (s *roomStatements) updateLatestEventNIDs( return err } -func (s *roomStatements) selectRoomVersionForRoomID( +func (s *roomStatements) SelectRoomVersionForRoomID( ctx context.Context, txn *sql.Tx, roomID string, ) (gomatrixserverlib.RoomVersion, error) { var roomVersion gomatrixserverlib.RoomVersion @@ -175,12 +178,11 @@ func (s *roomStatements) selectRoomVersionForRoomID( return roomVersion, err } -func (s *roomStatements) selectRoomVersionForRoomNID( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +func (s *roomStatements) SelectRoomVersionForRoomNID( + ctx context.Context, roomNID types.RoomNID, ) (gomatrixserverlib.RoomVersion, error) { var roomVersion gomatrixserverlib.RoomVersion - stmt := internal.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt) - err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) + err := s.selectRoomVersionForRoomNIDStmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) if err == sql.ErrNoRows { return roomVersion, errors.New("room not found") } diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go deleted file mode 100644 index 0d49432b8..000000000 --- a/roomserver/storage/sqlite3/sql.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "database/sql" -) - -type statements struct { - eventTypeStatements - eventStateKeyStatements - roomStatements - eventStatements - eventJSONStatements - stateSnapshotStatements - stateBlockStatements - previousEventStatements - roomAliasesStatements - inviteStatements - membershipStatements - transactionStatements -} - -func (s *statements) prepare(db *sql.DB) error { - var err error - - for _, prepare := range []func(db *sql.DB) error{ - s.eventTypeStatements.prepare, - s.eventStateKeyStatements.prepare, - s.roomStatements.prepare, - s.eventStatements.prepare, - s.eventJSONStatements.prepare, - s.stateSnapshotStatements.prepare, - s.stateBlockStatements.prepare, - s.previousEventStatements.prepare, - s.roomAliasesStatements.prepare, - s.inviteStatements.prepare, - s.membershipStatements.prepare, - s.transactionStatements.prepare, - } { - if err = prepare(db); err != nil { - return err - } - } - - return nil -} diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 861d76cf9..399fdbf63 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -23,6 +23,8 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" ) @@ -77,22 +79,23 @@ type stateBlockStatements struct { bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt } -func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { +func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { + s := &stateBlockStatements{} s.db = db - _, err = db.Exec(stateDataSchema) + _, err := db.Exec(stateDataSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *stateBlockStatements) bulkInsertStateData( +func (s *stateBlockStatements) BulkInsertStateData( ctx context.Context, txn *sql.Tx, entries []types.StateEntry, ) (types.StateBlockNID, error) { @@ -120,19 +123,18 @@ func (s *stateBlockStatements) bulkInsertStateData( return stateBlockNID, nil } -func (s *stateBlockStatements) bulkSelectStateBlockEntries( - ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID, +func (s *stateBlockStatements) BulkSelectStateBlockEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { nids := make([]interface{}, len(stateBlockNIDs)) for k, v := range stateBlockNIDs { nids[k] = v } selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", internal.QueryVariadic(len(nids)), 1) - selectPrep, err := s.db.Prepare(selectOrig) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } - selectStmt := internal.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, nids...) if err != nil { return nil, err @@ -174,8 +176,8 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries( return results, nil } -func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( - ctx context.Context, txn *sql.Tx, // nolint: unparam +func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntryList, error) { diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 8682627a0..5de3f639a 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -23,6 +23,8 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -51,20 +53,21 @@ type stateSnapshotStatements struct { bulkSelectStateBlockNIDsStmt *sql.Stmt } -func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) { +func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { + s := &stateSnapshotStatements{} s.db = db - _, err = db.Exec(stateSnapshotSchema) + _, err := db.Exec(stateSnapshotSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *stateSnapshotStatements) insertState( +func (s *stateSnapshotStatements) InsertState( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, ) (stateNID types.StateSnapshotNID, err error) { stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs) @@ -82,15 +85,15 @@ func (s *stateSnapshotStatements) insertState( return } -func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( - ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, +func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { nids := make([]interface{}, len(stateNIDs)) for k, v := range stateNIDs { nids[k] = v } selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", internal.QueryVariadic(len(nids)), 1) - selectStmt, err := txn.Prepare(selectOrig) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index bb38f800f..54a2f2658 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -18,14 +18,14 @@ package sqlite3 import ( "context" "database/sql" - "encoding/json" "errors" "net/url" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" _ "github.com/mattn/go-sqlite3" @@ -33,11 +33,21 @@ import ( // A Database is used to store room events and stream offsets. type Database struct { - statements statements - db *sql.DB + shared.Database + events tables.Events + eventJSON tables.EventJSON + eventTypes tables.EventTypes + eventStateKeys tables.EventStateKeys + rooms tables.Rooms + transactions tables.Transactions + prevEvents tables.PreviousEvents + invites tables.Invites + membership tables.Membership + db *sql.DB } // Open a sqlite database. +// nolint: gocyclo func Open(dataSourceName string) (*Database, error) { var d Database uri, err := url.Parse(dataSourceName) @@ -63,904 +73,94 @@ func Open(dataSourceName string) (*Database, error) { // acquire the global mutex and never unlock it because it is waiting for a connection // which it will never obtain. d.db.SetMaxOpenConns(20) - if err = d.statements.prepare(d.db); err != nil { + + d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db) + if err != nil { return nil, err } + d.eventTypes, err = NewSqliteEventTypesTable(d.db) + if err != nil { + return nil, err + } + d.eventJSON, err = NewSqliteEventJSONTable(d.db) + if err != nil { + return nil, err + } + d.events, err = NewSqliteEventsTable(d.db) + if err != nil { + return nil, err + } + d.rooms, err = NewSqliteRoomsTable(d.db) + if err != nil { + return nil, err + } + d.transactions, err = NewSqliteTransactionsTable(d.db) + if err != nil { + return nil, err + } + stateBlock, err := NewSqliteStateBlockTable(d.db) + if err != nil { + return nil, err + } + stateSnapshot, err := NewSqliteStateSnapshotTable(d.db) + if err != nil { + return nil, err + } + d.prevEvents, err = NewSqlitePrevEventsTable(d.db) + if err != nil { + return nil, err + } + roomAliases, err := NewSqliteRoomAliasesTable(d.db) + if err != nil { + return nil, err + } + d.invites, err = NewSqliteInvitesTable(d.db) + if err != nil { + return nil, err + } + d.membership, err = NewSqliteMembershipTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + EventsTable: d.events, + EventTypesTable: d.eventTypes, + EventStateKeysTable: d.eventStateKeys, + EventJSONTable: d.eventJSON, + RoomsTable: d.rooms, + TransactionsTable: d.transactions, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + PrevEventsTable: d.prevEvents, + RoomAliasesTable: roomAliases, + InvitesTable: d.invites, + MembershipTable: d.membership, + } return &d, nil } -// StoreEvent implements input.EventDatabase -func (d *Database) StoreEvent( - ctx context.Context, event gomatrixserverlib.Event, - txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, -) (types.RoomNID, types.StateAtEvent, error) { - var ( - roomNID types.RoomNID - eventTypeNID types.EventTypeNID - eventStateKeyNID types.EventStateKeyNID - eventNID types.EventNID - stateNID types.StateSnapshotNID - err error - ) - - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - if txnAndSessionID != nil { - if err = d.statements.insertTransaction( - ctx, txn, txnAndSessionID.TransactionID, - txnAndSessionID.SessionID, event.Sender(), event.EventID(), - ); err != nil { - return err - } - } - - // TODO: Here we should aim to have two different code paths for new rooms - // vs existing ones. - - // Get the default room version. If the client doesn't supply a room_version - // then we will use our configured default to create the room. - // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom - // Note that the below logic depends on the m.room.create event being the - // first event that is persisted to the database when creating or joining a - // room. - var roomVersion gomatrixserverlib.RoomVersion - if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { - return err - } - - if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil { - return err - } - - if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil { - return err - } - - eventStateKey := event.StateKey() - // Assigned a numeric ID for the state_key if there is one present. - // Otherwise set the numeric ID for the state_key to 0. - if eventStateKey != nil { - if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { - return err - } - } - - if eventNID, err = d.statements.insertEvent( - ctx, - txn, - roomNID, - eventTypeNID, - eventStateKeyNID, - event.EventID(), - event.EventReference().EventSHA256, - authEventNIDs, - event.Depth(), - ); err != nil { - if err == sql.ErrNoRows { - // We've already inserted the event so select the numeric event ID - eventNID, stateNID, err = d.statements.selectEvent(ctx, txn, event.EventID()) - } - if err != nil { - return err - } - } - - if err = d.statements.insertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { - return err - } - - return nil - }) - if err != nil { - return 0, types.StateAtEvent{}, err - } - - return roomNID, types.StateAtEvent{ - BeforeStateSnapshotNID: stateNID, - StateEntry: types.StateEntry{ - StateKeyTuple: types.StateKeyTuple{ - EventTypeNID: eventTypeNID, - EventStateKeyNID: eventStateKeyNID, - }, - EventNID: eventNID, - }, - }, nil -} - -func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) ( - gomatrixserverlib.RoomVersion, error, -) { - var err error - var roomVersion gomatrixserverlib.RoomVersion - // Look for m.room.create events. - if event.Type() != gomatrixserverlib.MRoomCreate { - return gomatrixserverlib.RoomVersion(""), nil - } - roomVersion = gomatrixserverlib.RoomVersionV1 - var createContent gomatrixserverlib.CreateContent - // The m.room.create event contains an optional "room_version" key in - // the event content, so we need to unmarshal that first. - if err = json.Unmarshal(event.Content(), &createContent); err != nil { - return gomatrixserverlib.RoomVersion(""), err - } - // A room version was specified in the event content? - if createContent.RoomVersion != nil { - roomVersion = gomatrixserverlib.RoomVersion(*createContent.RoomVersion) - } - return roomVersion, err -} - -func (d *Database) assignRoomNID( - ctx context.Context, txn *sql.Tx, - roomID string, roomVersion gomatrixserverlib.RoomVersion, -) (roomNID types.RoomNID, err error) { - // Check if we already have a numeric ID in the database. - roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID, roomVersion) - if err == nil { - // Now get the numeric ID back out of the database - roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) - } - } - return -} - -func (d *Database) assignEventTypeNID( - ctx context.Context, txn *sql.Tx, eventType string, -) (eventTypeNID types.EventTypeNID, err error) { - // Check if we already have a numeric ID in the database. - eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - eventTypeNID, err = d.statements.insertEventTypeNID(ctx, txn, eventType) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType) - } - } - return -} - -func (d *Database) assignStateKeyNID( - ctx context.Context, txn *sql.Tx, eventStateKey string, -) (eventStateKeyNID types.EventStateKeyNID, err error) { - // Check if we already have a numeric ID in the database. - eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) - } - } - return -} - -// StateEntriesForEventIDs implements input.EventDatabase -func (d *Database) StateEntriesForEventIDs( - ctx context.Context, eventIDs []string, -) (se []types.StateEntry, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - se, err = d.statements.bulkSelectStateEventByID(ctx, txn, eventIDs) - return err - }) - return -} - -// EventTypeNIDs implements state.RoomStateDatabase -func (d *Database) EventTypeNIDs( - ctx context.Context, eventTypes []string, -) (etnids map[string]types.EventTypeNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - etnids, err = d.statements.bulkSelectEventTypeNID(ctx, txn, eventTypes) - return err - }) - return -} - -// EventStateKeyNIDs implements state.RoomStateDatabase -func (d *Database) EventStateKeyNIDs( - ctx context.Context, eventStateKeys []string, -) (esknids map[string]types.EventStateKeyNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - esknids, err = d.statements.bulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) - return err - }) - return -} - -// EventStateKeys implements query.RoomserverQueryAPIDatabase -func (d *Database) EventStateKeys( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, -) (out map[types.EventStateKeyNID]string, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - out, err = d.statements.bulkSelectEventStateKey(ctx, txn, eventStateKeyNIDs) - return err - }) - return -} - -// EventNIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) EventNIDs( - ctx context.Context, eventIDs []string, -) (out map[string]types.EventNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - out, err = d.statements.bulkSelectEventNID(ctx, txn, eventIDs) - return err - }) - return -} - -// Events implements input.EventDatabase -func (d *Database) Events( - ctx context.Context, eventNIDs []types.EventNID, -) ([]types.Event, error) { - var eventJSONs []eventJSONPair - var err error - var results []types.Event - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs) - if err != nil || len(eventJSONs) == 0 { - return nil - } - results = make([]types.Event, len(eventJSONs)) - for i, eventJSON := range eventJSONs { - var roomNID types.RoomNID - var roomVersion gomatrixserverlib.RoomVersion - result := &results[i] - result.EventNID = eventJSON.EventNID - roomNID, err = d.statements.selectRoomNIDForEventNID(ctx, txn, eventJSON.EventNID) - if err != nil { - return err - } - roomVersion, err = d.statements.selectRoomVersionForRoomNID(ctx, txn, roomNID) - if err != nil { - return err - } - result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON( - eventJSON.EventJSON, false, roomVersion, - ) - if err != nil { - return nil - } - } - return nil - }) - if err != nil { - return []types.Event{}, err - } - return results, nil -} - -// AddState implements input.EventDatabase -func (d *Database) AddState( - ctx context.Context, - roomNID types.RoomNID, - stateBlockNIDs []types.StateBlockNID, - state []types.StateEntry, -) (stateNID types.StateSnapshotNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - if len(state) > 0 { - var stateBlockNID types.StateBlockNID - stateBlockNID, err = d.statements.bulkInsertStateData(ctx, txn, state) - if err != nil { - return err - } - stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) - } - stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs) - return err - }) - if err != nil { - return 0, err - } - return -} - -// SetState implements input.EventDatabase -func (d *Database) SetState( - ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, -) error { - e := internal.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.statements.updateEventState(ctx, txn, eventNID, stateNID) - }) - return e -} - -// StateAtEventIDs implements input.EventDatabase -func (d *Database) StateAtEventIDs( - ctx context.Context, eventIDs []string, -) (se []types.StateAtEvent, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - se, err = d.statements.bulkSelectStateAtEventByID(ctx, txn, eventIDs) - return err - }) - return -} - -// StateBlockNIDs implements state.RoomStateDatabase -func (d *Database) StateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, -) (sl []types.StateBlockNIDList, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - sl, err = d.statements.bulkSelectStateBlockNIDs(ctx, txn, stateNIDs) - return err - }) - return -} - -// StateEntries implements state.RoomStateDatabase -func (d *Database) StateEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) (sel []types.StateEntryList, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - sel, err = d.statements.bulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs) - return err - }) - return -} - -// SnapshotNIDFromEventID implements state.RoomStateDatabase -func (d *Database) SnapshotNIDFromEventID( - ctx context.Context, eventID string, -) (stateNID types.StateSnapshotNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - _, stateNID, err = d.statements.selectEvent(ctx, txn, eventID) - return err - }) - return -} - -// EventIDs implements input.RoomEventDatabase -func (d *Database) EventIDs( - ctx context.Context, eventNIDs []types.EventNID, -) (out map[types.EventNID]string, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - out, err = d.statements.bulkSelectEventID(ctx, txn, eventNIDs) - return err - }) - return -} - -// GetLatestEventsForUpdate implements input.EventDatabase func (d *Database) GetLatestEventsForUpdate( ctx context.Context, roomNID types.RoomNID, ) (types.RoomRecentEventsUpdater, error) { - txn, err := d.db.Begin() - if err != nil { - return nil, err - } - eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := - d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID) - if err != nil { - txn.Rollback() // nolint: errcheck - return nil, err - } - stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) - if err != nil { - txn.Rollback() // nolint: errcheck - return nil, err - } - var lastEventIDSent string - if lastEventNIDSent != 0 { - lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent) - if err != nil { - txn.Rollback() // nolint: errcheck - return nil, err - } - } - - // FIXME: we probably want to support long-lived txns in sqlite somehow, but we don't because we get - // 'database is locked' errors caused by multiple write txns (one being the long-lived txn created here) - // so for now let's not use a long-lived txn at all, and just commit it here and set the txn to nil so - // we fail fast if someone tries to use the underlying txn object. - err = txn.Commit() - if err != nil { - return nil, err - } - return &roomRecentEventsUpdater{ - transaction{ctx, nil}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, - }, nil + // TODO: Do not use transactions. We should be holding open this transaction but we cannot have + // multiple write transactions on sqlite. The code will perform additional + // write transactions independent of this one which will consistently cause + // 'database is locked' errors. As sqlite doesn't support multi-process on the + // same DB anyway, and we only execute updates sequentially, the only worries + // are for rolling back when things go wrong. (atomicity) + return shared.NewRoomRecentEventsUpdater(&d.Database, ctx, roomNID, false) } -// GetTransactionEventID implements input.EventDatabase -func (d *Database) GetTransactionEventID( - ctx context.Context, transactionID string, - sessionID int64, userID string, -) (string, error) { - eventID, err := d.statements.selectTransactionEventID(ctx, nil, transactionID, sessionID, userID) - if err == sql.ErrNoRows { - return "", nil - } - return eventID, err -} - -type roomRecentEventsUpdater struct { - transaction - d *Database - roomNID types.RoomNID - latestEvents []types.StateAtEventAndReference - lastEventIDSent string - currentStateSnapshotNID types.StateSnapshotNID -} - -// RoomVersion implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { - version, _ = u.d.GetRoomVersionForRoomNID(u.ctx, u.roomNID) - return -} - -// LatestEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference { - return u.latestEvents -} - -// LastEventIDSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) LastEventIDSent() string { - return u.lastEventIDSent -} - -// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { - return u.currentStateSnapshotNID -} - -// StorePreviousEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - err := internal.WithTransaction(u.d.db, func(txn *sql.Tx) error { - for _, ref := range previousEventReferences { - if err := u.d.statements.insertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return err - } - } - return nil - }) - return err -} - -// IsReferenced implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (res bool, err error) { - err = internal.WithTransaction(u.d.db, func(txn *sql.Tx) error { - err := u.d.statements.selectPreviousEventExists(u.ctx, txn, eventReference.EventID, eventReference.EventSHA256) - if err == nil { - res = true - err = nil - } - if err == sql.ErrNoRows { - res = false - err = nil - } - return err - }) - return -} - -// SetLatestEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) SetLatestEvents( - roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, - currentStateSnapshotNID types.StateSnapshotNID, -) error { - err := internal.WithTransaction(u.d.db, func(txn *sql.Tx) error { - eventNIDs := make([]types.EventNID, len(latest)) - for i := range latest { - eventNIDs[i] = latest[i].EventNID - } - return u.d.statements.updateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) - }) - return err -} - -// HasEventBeenSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (res bool, err error) { - err = internal.WithTransaction(u.d.db, func(txn *sql.Tx) error { - res, err = u.d.statements.selectEventSentToOutput(u.ctx, txn, eventNID) - return err - }) - return -} - -// MarkEventAsSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - err := internal.WithTransaction(u.d.db, func(txn *sql.Tx) error { - return u.d.statements.updateEventSentToOutput(u.ctx, txn, eventNID) - }) - return err -} - -func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (mu types.MembershipUpdater, err error) { - err = internal.WithTransaction(u.d.db, func(txn *sql.Tx) error { - mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID, targetLocal) - return err - }) - return -} - -// RoomNID implements query.RoomserverQueryAPIDB -func (d *Database) RoomNID(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) - if err == sql.ErrNoRows { - roomNID = 0 - err = nil - } - return err - }) - return -} - -// RoomNIDExcludingStubs implements query.RoomserverQueryAPIDB -func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) { - roomNID, err = d.RoomNID(ctx, roomID) - if err != nil { - return - } - latestEvents, _, err := d.statements.selectLatestEventNIDs(ctx, nil, roomNID) - if err != nil { - return - } - if len(latestEvents) == 0 { - roomNID = 0 - return - } - return -} - -// LatestEventIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) LatestEventIDs( - ctx context.Context, roomNID types.RoomNID, -) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - var eventNIDs []types.EventNID - eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, txn, roomNID) - if err != nil { - return err - } - references, err = d.statements.bulkSelectEventReference(ctx, txn, eventNIDs) - if err != nil { - return err - } - depth, err = d.statements.selectMaxEventDepth(ctx, txn, eventNIDs) - if err != nil { - return err - } - return nil - }) - return -} - -// GetInvitesForUser implements query.RoomserverQueryAPIDatabase -func (d *Database) GetInvitesForUser( - ctx context.Context, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, -) (senderUserIDs []types.EventStateKeyNID, err error) { - return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) -} - -// SetRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { - return d.statements.insertRoomAlias(ctx, nil, alias, roomID, creatorUserID) -} - -// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB -func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { - return d.statements.selectRoomIDFromAlias(ctx, nil, alias) -} - -// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB -func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { - return d.statements.selectAliasesFromRoomID(ctx, nil, roomID) -} - -// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB -func (d *Database) GetCreatorIDForAlias( - ctx context.Context, alias string, -) (string, error) { - return d.statements.selectCreatorIDFromAlias(ctx, nil, alias) -} - -// RemoveRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { - return d.statements.deleteRoomAlias(ctx, nil, alias) -} - -// StateEntriesForTuples implements state.RoomStateDatabase -func (d *Database) StateEntriesForTuples( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - return d.statements.bulkSelectFilteredStateBlockEntries( - ctx, nil, stateBlockNIDs, stateKeyTuples, - ) -} - -// MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, ) (updater types.MembershipUpdater, err error) { - var txn *sql.Tx - txn, err = d.db.Begin() - if err != nil { - return nil, err - } - succeeded := false - defer func() { - if !succeeded { - txn.Rollback() // nolint: errcheck - } else { - // TODO: We should be holding open this transaction but we cannot have - // multiple write transactions on sqlite. The code will perform additional - // write transactions independent of this one which will consistently cause - // 'database is locked' errors. For now, we'll break up the transaction and - // hope we don't race too catastrophically. Long term, we should be able to - // thread in txn objects where appropriate (either at the interface level or - // bring matrix business logic into the storage layer). - txerr := txn.Commit() - if err == nil && txerr != nil { - err = txerr - } - } - }() - - roomNID, err := d.assignRoomNID(ctx, txn, roomID, roomVersion) - if err != nil { - return nil, err - } - - targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID) - if err != nil { - return nil, err - } - - updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal) - if err != nil { - return nil, err - } - - succeeded = true - return updater, nil -} - -type membershipUpdater struct { - transaction - d *Database - roomNID types.RoomNID - targetUserNID types.EventStateKeyNID - membership membershipState -} - -func (d *Database) membershipUpdaterTxn( - ctx context.Context, - txn *sql.Tx, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, - targetLocal bool, -) (types.MembershipUpdater, error) { - - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { - return nil, err - } - - membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) - if err != nil { - return nil, err - } - - return &membershipUpdater{ - // purposefully set the txn to nil so if we try to use it we panic and fail fast - transaction{ctx, nil}, d, roomNID, targetUserNID, membership, - }, nil -} - -// IsInvite implements types.MembershipUpdater -func (u *membershipUpdater) IsInvite() bool { - return u.membership == membershipStateInvite -} - -// IsJoin implements types.MembershipUpdater -func (u *membershipUpdater) IsJoin() bool { - return u.membership == membershipStateJoin -} - -// IsLeave implements types.MembershipUpdater -func (u *membershipUpdater) IsLeave() bool { - return u.membership == membershipStateLeaveOrBan -} - -// SetToInvite implements types.MembershipUpdater -func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted bool, err error) { - err = internal.WithTransaction(u.d.db, func(txn *sql.Tx) error { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, event.Sender()) - if err != nil { - return err - } - inserted, err = u.d.statements.insertInviteEvent( - u.ctx, txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), - ) - if err != nil { - return err - } - if u.membership != membershipStateInvite { - if err = u.d.statements.updateMembership( - u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, - ); err != nil { - return err - } - } - return nil - }) - return -} - -// SetToJoin implements types.MembershipUpdater -func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error) { - err = internal.WithTransaction(u.d.db, func(txn *sql.Tx) error { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID) - if err != nil { - return err - } - - // If this is a join event update, there is no invite to update - if !isUpdate { - inviteEventIDs, err = u.d.statements.updateInviteRetired( - u.ctx, txn, u.roomNID, u.targetUserNID, - ) - if err != nil { - return err - } - } - - // Look up the NID of the new join event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return err - } - - if u.membership != membershipStateJoin || isUpdate { - if err = u.d.statements.updateMembership( - u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateJoin, nIDs[eventID], - ); err != nil { - return err - } - } - return nil - }) - - return -} - -// SetToLeave implements types.MembershipUpdater -func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error) { - err = internal.WithTransaction(u.d.db, func(txn *sql.Tx) error { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID) - if err != nil { - return err - } - inviteEventIDs, err = u.d.statements.updateInviteRetired( - u.ctx, txn, u.roomNID, u.targetUserNID, - ) - if err != nil { - return err - } - - // Look up the NID of the new leave event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return err - } - - if u.membership != membershipStateLeaveOrBan { - if err = u.d.statements.updateMembership( - u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateLeaveOrBan, nIDs[eventID], - ); err != nil { - return err - } - } - return nil - }) - return -} - -// GetMembership implements query.RoomserverQueryAPIDB -func (d *Database) GetMembership( - ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, -) (membershipEventNID types.EventNID, stillInRoom bool, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - requestSenderUserNID, err := d.assignStateKeyNID(ctx, txn, requestSenderUserID) - if err != nil { - return err - } - - membershipEventNID, _, err = - d.statements.selectMembershipFromRoomAndTarget( - ctx, txn, roomNID, requestSenderUserNID, - ) - if err == sql.ErrNoRows { - // The user has never been a member of that room - return nil - } - if err != nil { - return err - } - stillInRoom = true - return nil - }) - - return -} - -// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB -func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, -) (eventNIDs []types.EventNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - if joinOnly { - eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership( - ctx, txn, roomNID, membershipStateJoin, localOnly, - ) - return nil - } - - eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID, localOnly) - return nil - }) - return -} - -// EventsFromIDs implements query.RoomserverQueryAPIEventDB -func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - nidMap, err := d.EventNIDs(ctx, eventIDs) - if err != nil { - return nil, err - } - - var nids []types.EventNID - for _, nid := range nidMap { - nids = append(nids, nid) - } - - return d.Events(ctx, nids) -} - -func (d *Database) GetRoomVersionForRoom( - ctx context.Context, roomID string, -) (gomatrixserverlib.RoomVersion, error) { - return d.statements.selectRoomVersionForRoomID( - ctx, nil, roomID, - ) -} - -func (d *Database) GetRoomVersionForRoomNID( - ctx context.Context, roomNID types.RoomNID, -) (gomatrixserverlib.RoomVersion, error) { - return d.statements.selectRoomVersionForRoomNID( - ctx, nil, roomNID, - ) -} - -type transaction struct { - ctx context.Context - txn *sql.Tx -} - -// Commit implements types.Transaction -func (t *transaction) Commit() error { - if t.txn == nil { - return nil - } - return t.txn.Commit() -} - -// Rollback implements types.Transaction -func (t *transaction) Rollback() error { - if t.txn == nil { - return nil - } - return t.txn.Rollback() + // TODO: Do not use transactions. We should be holding open this transaction but we cannot have + // multiple write transactions on sqlite. The code will perform additional + // write transactions independent of this one which will consistently cause + // 'database is locked' errors. As sqlite doesn't support multi-process on the + // same DB anyway, and we only execute updates sequentially, the only worries + // are for rolling back when things go wrong. (atomicity) + return shared.NewMembershipUpdater(ctx, &d.Database, roomID, targetUserID, targetLocal, roomVersion, false) } diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index d22c73845..7b489ac6e 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -20,6 +20,8 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" ) const transactionsSchema = ` @@ -46,19 +48,20 @@ type transactionStatements struct { selectTransactionEventIDStmt *sql.Stmt } -func (s *transactionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(transactionsSchema) +func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { + s := &transactionStatements{} + _, err := db.Exec(transactionsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertTransactionStmt, insertTransactionSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *transactionStatements) insertTransaction( +func (s *transactionStatements) InsertTransaction( ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, @@ -72,14 +75,13 @@ func (s *transactionStatements) insertTransaction( return } -func (s *transactionStatements) selectTransactionEventID( - ctx context.Context, txn *sql.Tx, +func (s *transactionStatements) SelectTransactionEventID( + ctx context.Context, transactionID string, sessionID int64, userID string, ) (eventID string, err error) { - stmt := internal.TxStmt(txn, s.selectTransactionEventIDStmt) - err = stmt.QueryRowContext( + err = s.selectTransactionEventIDStmt.QueryRowContext( ctx, transactionID, sessionID, userID, ).Scan(&eventID) return diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go new file mode 100644 index 000000000..11cff8a8b --- /dev/null +++ b/roomserver/storage/tables/interface.go @@ -0,0 +1,122 @@ +package tables + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type EventJSONPair struct { + EventNID types.EventNID + EventJSON []byte +} + +type EventJSON interface { + InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error + BulkSelectEventJSON(ctx context.Context, eventNIDs []types.EventNID) ([]EventJSONPair, error) +} + +type EventTypes interface { + InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) + SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) + BulkSelectEventTypeNID(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) +} + +type EventStateKeys interface { + InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) + SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) + BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) +} + +type Events interface { + InsertEvent(c context.Context, txn *sql.Tx, i types.RoomNID, j types.EventTypeNID, k types.EventStateKeyNID, eventID string, referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64) (types.EventNID, types.StateSnapshotNID, error) + SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) + // bulkSelectStateEventByID lookups a list of state events by event ID. + // If any of the requested events are missing from the database it returns a types.MissingEventError + BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + // BulkSelectStateAtEventByID lookups the state at a list of events by event ID. + // If any of the requested events are missing from the database it returns a types.MissingEventError. + // If we do not have the state for any of the requested events it returns a types.MissingEventError. + BulkSelectStateAtEventByID(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + UpdateEventState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error + SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) + UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error + SelectEventID(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) + BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) + BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) + // BulkSelectEventID returns a map from numeric event ID to string event ID. + BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. + // If an event ID is not in the database then it is omitted from the map. + BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) + SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) + SelectRoomNIDForEventNID(ctx context.Context, eventNID types.EventNID) (roomNID types.RoomNID, err error) +} + +type Rooms interface { + InsertRoomNID(ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion) (types.RoomNID, error) + SelectRoomNID(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error) + SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) + SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) + UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error + SelectRoomVersionForRoomID(ctx context.Context, txn *sql.Tx, roomID string) (gomatrixserverlib.RoomVersion, error) + SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) +} + +type Transactions interface { + InsertTransaction(ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, userID string, eventID string) error + SelectTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (eventID string, err error) +} + +type StateSnapshot interface { + InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) + BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) +} + +type StateBlock interface { + BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries []types.StateEntry) (types.StateBlockNID, error) + BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) +} + +type RoomAliases interface { + InsertRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) (err error) + SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error) + SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) + SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error) + DeleteRoomAlias(ctx context.Context, alias string) (err error) +} + +type PreviousEvents interface { + InsertPreviousEvent(ctx context.Context, txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error + // Check if the event reference exists + // Returns sql.ErrNoRows if the event reference doesn't exist. + SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error +} + +type Invites interface { + InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error) + UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error) + // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs + SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, error) +} + +type MembershipState int64 + +const ( + MembershipStateLeaveOrBan MembershipState = 1 + MembershipStateInvite MembershipState = 2 + MembershipStateJoin MembershipState = 3 +) + +type Membership interface { + InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error + SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error) + SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, error) + SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) + SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) + UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error +} diff --git a/roomserver/version/version.go b/roomserver/version/version.go index ddd0f23a6..f2b15ec39 100644 --- a/roomserver/version/version.go +++ b/roomserver/version/version.go @@ -20,42 +20,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// RoomVersionDescription contains information about a room version, -// namely whether it is marked as supported or stable in this server -// version. -// A version is supported if the server has some support for rooms -// that are this version. A version is marked as stable or unstable -// in order to hint whether the version should be used to clients -// calling the /capabilities endpoint. -// https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-capabilities -type RoomVersionDescription struct { - Supported bool - Stable bool -} - -var roomVersions = map[gomatrixserverlib.RoomVersion]RoomVersionDescription{ - gomatrixserverlib.RoomVersionV1: RoomVersionDescription{ - Supported: true, - Stable: true, - }, - gomatrixserverlib.RoomVersionV2: RoomVersionDescription{ - Supported: true, - Stable: true, - }, - gomatrixserverlib.RoomVersionV3: RoomVersionDescription{ - Supported: true, - Stable: true, - }, - gomatrixserverlib.RoomVersionV4: RoomVersionDescription{ - Supported: true, - Stable: true, - }, - gomatrixserverlib.RoomVersionV5: RoomVersionDescription{ - Supported: true, - Stable: true, - }, -} - // DefaultRoomVersion contains the room version that will, by // default, be used to create new rooms on this server. func DefaultRoomVersion() gomatrixserverlib.RoomVersion { @@ -64,43 +28,37 @@ func DefaultRoomVersion() gomatrixserverlib.RoomVersion { // RoomVersions returns a map of all known room versions to this // server. -func RoomVersions() map[gomatrixserverlib.RoomVersion]RoomVersionDescription { - return roomVersions +func RoomVersions() map[gomatrixserverlib.RoomVersion]gomatrixserverlib.RoomVersionDescription { + return gomatrixserverlib.RoomVersions() } // SupportedRoomVersions returns a map of descriptions for room // versions that are supported by this homeserver. -func SupportedRoomVersions() map[gomatrixserverlib.RoomVersion]RoomVersionDescription { - versions := make(map[gomatrixserverlib.RoomVersion]RoomVersionDescription) - for id, version := range RoomVersions() { - if version.Supported { - versions[id] = version - } - } - return versions +func SupportedRoomVersions() map[gomatrixserverlib.RoomVersion]gomatrixserverlib.RoomVersionDescription { + return gomatrixserverlib.SupportedRoomVersions() } // RoomVersion returns information about a specific room version. // An UnknownVersionError is returned if the version is not known // to the server. -func RoomVersion(version gomatrixserverlib.RoomVersion) (RoomVersionDescription, error) { - if version, ok := roomVersions[version]; ok { +func RoomVersion(version gomatrixserverlib.RoomVersion) (gomatrixserverlib.RoomVersionDescription, error) { + if version, ok := gomatrixserverlib.RoomVersions()[version]; ok { return version, nil } - return RoomVersionDescription{}, UnknownVersionError{version} + return gomatrixserverlib.RoomVersionDescription{}, UnknownVersionError{version} } // SupportedRoomVersion returns information about a specific room // version. An UnknownVersionError is returned if the version is not // known to the server, or an UnsupportedVersionError is returned if // the version is known but specifically marked as unsupported. -func SupportedRoomVersion(version gomatrixserverlib.RoomVersion) (RoomVersionDescription, error) { +func SupportedRoomVersion(version gomatrixserverlib.RoomVersion) (gomatrixserverlib.RoomVersionDescription, error) { result, err := RoomVersion(version) if err != nil { - return RoomVersionDescription{}, err + return gomatrixserverlib.RoomVersionDescription{}, err } if !result.Supported { - return RoomVersionDescription{}, UnsupportedVersionError{version} + return gomatrixserverlib.RoomVersionDescription{}, UnsupportedVersionError{version} } return result, nil } diff --git a/serverkeyapi/api/api.go b/serverkeyapi/api/api.go new file mode 100644 index 000000000..f108d437b --- /dev/null +++ b/serverkeyapi/api/api.go @@ -0,0 +1,113 @@ +package api + +import ( + "context" + "errors" + "net/http" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/gomatrixserverlib" +) + +type ServerKeyInternalAPI interface { + gomatrixserverlib.KeyDatabase + + KeyRing() *gomatrixserverlib.KeyRing + + InputPublicKeys( + ctx context.Context, + request *InputPublicKeysRequest, + response *InputPublicKeysResponse, + ) error + + QueryPublicKeys( + ctx context.Context, + request *QueryPublicKeysRequest, + response *QueryPublicKeysResponse, + ) error +} + +// NewRoomserverInputAPIHTTP creates a RoomserverInputAPI implemented by talking to a HTTP POST API. +// If httpClient is nil an error is returned +func NewServerKeyInternalAPIHTTP( + serverKeyAPIURL string, + httpClient *http.Client, + immutableCache caching.ImmutableCache, +) (ServerKeyInternalAPI, error) { + if httpClient == nil { + return nil, errors.New("NewRoomserverInternalAPIHTTP: httpClient is ") + } + return &httpServerKeyInternalAPI{ + serverKeyAPIURL: serverKeyAPIURL, + httpClient: httpClient, + immutableCache: immutableCache, + }, nil +} + +type httpServerKeyInternalAPI struct { + ServerKeyInternalAPI + + serverKeyAPIURL string + httpClient *http.Client + immutableCache caching.ImmutableCache +} + +func (s *httpServerKeyInternalAPI) KeyRing() *gomatrixserverlib.KeyRing { + // This is a bit of a cheat - we tell gomatrixserverlib that this API is + // both the key database and the key fetcher. While this does have the + // rather unfortunate effect of preventing gomatrixserverlib from handling + // key fetchers directly, we can at least reimplement this behaviour on + // the other end of the API. + return &gomatrixserverlib.KeyRing{ + KeyDatabase: s, + KeyFetchers: []gomatrixserverlib.KeyFetcher{s}, + } +} + +func (s *httpServerKeyInternalAPI) FetcherName() string { + return "httpServerKeyInternalAPI" +} + +func (s *httpServerKeyInternalAPI) StoreKeys( + ctx context.Context, + results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) error { + request := InputPublicKeysRequest{ + Keys: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult), + } + response := InputPublicKeysResponse{} + for req, res := range results { + request.Keys[req] = res + s.immutableCache.StoreServerKey(req, res) + } + return s.InputPublicKeys(ctx, &request, &response) +} + +func (s *httpServerKeyInternalAPI) FetchKeys( + ctx context.Context, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, +) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + result := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) + request := QueryPublicKeysRequest{ + Requests: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp), + } + response := QueryPublicKeysResponse{ + Results: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult), + } + for req, ts := range requests { + if res, ok := s.immutableCache.GetServerKey(req); ok { + result[req] = res + continue + } + request.Requests[req] = ts + } + err := s.QueryPublicKeys(ctx, &request, &response) + if err != nil { + return nil, err + } + for req, res := range response.Results { + result[req] = res + s.immutableCache.StoreServerKey(req, res) + } + return result, nil +} diff --git a/serverkeyapi/api/http.go b/serverkeyapi/api/http.go new file mode 100644 index 000000000..ab89f7bfc --- /dev/null +++ b/serverkeyapi/api/http.go @@ -0,0 +1,57 @@ +package api + +import ( + "context" + + commonHTTP "github.com/matrix-org/dendrite/internal/http" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/opentracing/opentracing-go" +) + +const ( + // ServerKeyInputPublicKeyPath is the HTTP path for the InputPublicKeys API. + ServerKeyInputPublicKeyPath = "/serverkeyapi/inputPublicKey" + + // ServerKeyQueryPublicKeyPath is the HTTP path for the QueryPublicKeys API. + ServerKeyQueryPublicKeyPath = "/serverkeyapi/queryPublicKey" +) + +type InputPublicKeysRequest struct { + Keys map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"keys"` +} + +type InputPublicKeysResponse struct { +} + +func (h *httpServerKeyInternalAPI) InputPublicKeys( + ctx context.Context, + request *InputPublicKeysRequest, + response *InputPublicKeysResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputPublicKey") + defer span.Finish() + + apiURL := h.serverKeyAPIURL + ServerKeyInputPublicKeyPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +type QueryPublicKeysRequest struct { + Requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp `json:"requests"` +} + +type QueryPublicKeysResponse struct { + Results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"results"` +} + +func (h *httpServerKeyInternalAPI) QueryPublicKeys( + ctx context.Context, + request *QueryPublicKeysRequest, + response *QueryPublicKeysResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublicKey") + defer span.Finish() + + apiURL := h.serverKeyAPIURL + ServerKeyQueryPublicKeyPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/serverkeyapi/internal/api.go b/serverkeyapi/internal/api.go new file mode 100644 index 000000000..cbae59d9a --- /dev/null +++ b/serverkeyapi/internal/api.go @@ -0,0 +1,76 @@ +package internal + +import ( + "context" + "fmt" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/serverkeyapi/api" + "github.com/matrix-org/gomatrixserverlib" +) + +type ServerKeyAPI struct { + api.ServerKeyInternalAPI + + ImmutableCache caching.ImmutableCache + OurKeyRing gomatrixserverlib.KeyRing + FedClient *gomatrixserverlib.FederationClient +} + +func (s *ServerKeyAPI) KeyRing() *gomatrixserverlib.KeyRing { + // Return a real keyring - one that has the real database and real + // fetchers. + return &s.OurKeyRing +} + +func (s *ServerKeyAPI) StoreKeys( + ctx context.Context, + results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) error { + // Store any keys that we were given in our database. + return s.OurKeyRing.KeyDatabase.StoreKeys(ctx, results) +} + +func (s *ServerKeyAPI) FetchKeys( + ctx context.Context, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, +) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} + // First consult our local database and see if we have the requested + // keys. These might come from a cache, depending on the database + // implementation used. + if dbResults, err := s.OurKeyRing.KeyDatabase.FetchKeys(ctx, requests); err == nil { + // We successfully got some keys. Add them to the results and + // remove them from the request list. + for req, res := range dbResults { + results[req] = res + delete(requests, req) + } + } + // For any key requests that we still have outstanding, next try to + // fetch them directly. We'll go through each of the key fetchers to + // ask for the remaining keys. + for _, fetcher := range s.OurKeyRing.KeyFetchers { + if len(requests) == 0 { + break + } + if fetcherResults, err := fetcher.FetchKeys(ctx, requests); err == nil { + // We successfully got some keys. Add them to the results and + // remove them from the request list. + for req, res := range fetcherResults { + results[req] = res + delete(requests, req) + } + } + } + // If we failed to fetch any keys then we should report an error. + if len(requests) > 0 { + return results, fmt.Errorf("server key API failed to fetch %d keys", len(requests)) + } + // Return the keys. + return results, nil +} + +func (s *ServerKeyAPI) FetcherName() string { + return s.OurKeyRing.KeyDatabase.FetcherName() +} diff --git a/serverkeyapi/internal/http.go b/serverkeyapi/internal/http.go new file mode 100644 index 000000000..303275712 --- /dev/null +++ b/serverkeyapi/internal/http.go @@ -0,0 +1,60 @@ +package internal + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/serverkeyapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +func (s *ServerKeyAPI) SetupHTTP(internalAPIMux *mux.Router) { + internalAPIMux.Handle(api.ServerKeyQueryPublicKeyPath, + internal.MakeInternalAPI("queryPublicKeys", func(req *http.Request) util.JSONResponse { + result := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} + request := api.QueryPublicKeysRequest{} + response := api.QueryPublicKeysResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + lookup := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) + for req, timestamp := range request.Requests { + if res, ok := s.ImmutableCache.GetServerKey(req); ok { + result[req] = res + continue + } + lookup[req] = timestamp + } + keys, err := s.FetchKeys(req.Context(), lookup) + if err != nil { + return util.ErrorResponse(err) + } + for req, res := range keys { + result[req] = res + } + response.Results = result + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(api.ServerKeyInputPublicKeyPath, + internal.MakeInternalAPI("inputPublicKeys", func(req *http.Request) util.JSONResponse { + request := api.InputPublicKeysRequest{} + response := api.InputPublicKeysResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + store := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) + for req, res := range request.Keys { + store[req] = res + s.ImmutableCache.StoreServerKey(req, res) + } + if err := s.StoreKeys(req.Context(), store); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/serverkeyapi/serverkeyapi.go b/serverkeyapi/serverkeyapi.go new file mode 100644 index 000000000..61d807d64 --- /dev/null +++ b/serverkeyapi/serverkeyapi.go @@ -0,0 +1,83 @@ +package serverkeyapi + +import ( + "crypto/ed25519" + "encoding/base64" + + "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/serverkeyapi/api" + "github.com/matrix-org/dendrite/serverkeyapi/internal" + "github.com/matrix-org/dendrite/serverkeyapi/storage" + "github.com/matrix-org/dendrite/serverkeyapi/storage/cache" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +func SetupServerKeyAPIComponent( + base *basecomponent.BaseDendrite, + fedClient *gomatrixserverlib.FederationClient, +) api.ServerKeyInternalAPI { + innerDB, err := storage.NewDatabase( + string(base.Cfg.Database.ServerKey), + base.Cfg.DbProperties(), + base.Cfg.Matrix.ServerName, + base.Cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey), + base.Cfg.Matrix.KeyID, + ) + if err != nil { + logrus.WithError(err).Panicf("failed to connect to server key database") + } + + serverKeyDB, err := cache.NewKeyDatabase(innerDB, base.ImmutableCache) + if err != nil { + logrus.WithError(err).Panicf("failed to set up caching wrapper for server key database") + } + + internalAPI := internal.ServerKeyAPI{ + ImmutableCache: base.ImmutableCache, + FedClient: fedClient, + OurKeyRing: gomatrixserverlib.KeyRing{ + KeyFetchers: []gomatrixserverlib.KeyFetcher{ + &gomatrixserverlib.DirectKeyFetcher{ + Client: fedClient.Client, + }, + }, + KeyDatabase: serverKeyDB, + }, + } + + var b64e = base64.StdEncoding.WithPadding(base64.NoPadding) + for _, ps := range base.Cfg.Matrix.KeyPerspectives { + perspective := &gomatrixserverlib.PerspectiveKeyFetcher{ + PerspectiveServerName: ps.ServerName, + PerspectiveServerKeys: map[gomatrixserverlib.KeyID]ed25519.PublicKey{}, + Client: fedClient.Client, + } + + for _, key := range ps.Keys { + rawkey, err := b64e.DecodeString(key.PublicKey) + if err != nil { + logrus.WithError(err).WithFields(logrus.Fields{ + "server_name": ps.ServerName, + "public_key": key.PublicKey, + }).Warn("Couldn't parse perspective key") + continue + } + perspective.PerspectiveServerKeys[key.KeyID] = rawkey + } + + internalAPI.OurKeyRing.KeyFetchers = append( + internalAPI.OurKeyRing.KeyFetchers, + perspective, + ) + + logrus.WithFields(logrus.Fields{ + "server_name": ps.ServerName, + "num_public_keys": len(ps.Keys), + }).Info("Enabled perspective key fetcher") + } + + internalAPI.SetupHTTP(base.InternalAPIMux) + + return &internalAPI +} diff --git a/internal/keydb/cache/keydb.go b/serverkeyapi/storage/cache/keydb.go similarity index 91% rename from internal/keydb/cache/keydb.go rename to serverkeyapi/storage/cache/keydb.go index 87573ed2b..a0cdb9007 100644 --- a/internal/keydb/cache/keydb.go +++ b/serverkeyapi/storage/cache/keydb.go @@ -5,18 +5,17 @@ import ( "errors" "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/internal/keydb" "github.com/matrix-org/gomatrixserverlib" ) // A Database implements gomatrixserverlib.KeyDatabase and is used to store // the public keys for other matrix servers. type KeyDatabase struct { - inner keydb.Database + inner gomatrixserverlib.KeyDatabase cache caching.ImmutableCache } -func NewKeyDatabase(inner keydb.Database, cache caching.ImmutableCache) (*KeyDatabase, error) { +func NewKeyDatabase(inner gomatrixserverlib.KeyDatabase, cache caching.ImmutableCache) (*KeyDatabase, error) { if inner == nil { return nil, errors.New("inner database can't be nil") } diff --git a/internal/keydb/interface.go b/serverkeyapi/storage/interface.go similarity index 96% rename from internal/keydb/interface.go rename to serverkeyapi/storage/interface.go index c9a20fdd9..3a67ac55a 100644 --- a/internal/keydb/interface.go +++ b/serverkeyapi/storage/interface.go @@ -1,4 +1,4 @@ -package keydb +package storage import ( "context" diff --git a/internal/keydb/keydb.go b/serverkeyapi/storage/keydb.go similarity index 91% rename from internal/keydb/keydb.go rename to serverkeyapi/storage/keydb.go index ad6d56b86..b9389bd63 100644 --- a/internal/keydb/keydb.go +++ b/serverkeyapi/storage/keydb.go @@ -14,7 +14,7 @@ // +build !wasm -package keydb +package storage import ( "net/url" @@ -22,8 +22,8 @@ import ( "golang.org/x/crypto/ed25519" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/keydb/postgres" - "github.com/matrix-org/dendrite/internal/keydb/sqlite3" + "github.com/matrix-org/dendrite/serverkeyapi/storage/postgres" + "github.com/matrix-org/dendrite/serverkeyapi/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/internal/keydb/keydb_wasm.go b/serverkeyapi/storage/keydb_wasm.go similarity index 93% rename from internal/keydb/keydb_wasm.go rename to serverkeyapi/storage/keydb_wasm.go index 349381ee7..3d2ca5505 100644 --- a/internal/keydb/keydb_wasm.go +++ b/serverkeyapi/storage/keydb_wasm.go @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -package keydb +// +build wasm + +package storage import ( "fmt" @@ -21,7 +23,7 @@ import ( "golang.org/x/crypto/ed25519" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/keydb/sqlite3" + "github.com/matrix-org/dendrite/serverkeyapi/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/internal/keydb/postgres/keydb.go b/serverkeyapi/storage/postgres/keydb.go similarity index 100% rename from internal/keydb/postgres/keydb.go rename to serverkeyapi/storage/postgres/keydb.go diff --git a/internal/keydb/postgres/server_key_table.go b/serverkeyapi/storage/postgres/server_key_table.go similarity index 100% rename from internal/keydb/postgres/server_key_table.go rename to serverkeyapi/storage/postgres/server_key_table.go diff --git a/internal/keydb/sqlite3/keydb.go b/serverkeyapi/storage/sqlite3/keydb.go similarity index 90% rename from internal/keydb/sqlite3/keydb.go rename to serverkeyapi/storage/sqlite3/keydb.go index d1dc61c98..1ad6e7e88 100644 --- a/internal/keydb/sqlite3/keydb.go +++ b/serverkeyapi/storage/sqlite3/keydb.go @@ -17,6 +17,8 @@ package sqlite3 import ( "context" + "errors" + "net/url" "time" "golang.org/x/crypto/ed25519" @@ -44,7 +46,19 @@ func NewDatabase( serverKey ed25519.PublicKey, serverKeyID gomatrixserverlib.KeyID, ) (*Database, error) { - db, err := sqlutil.Open(internal.SQLiteDriverName(), dataSourceName, nil) + uri, err := url.Parse(dataSourceName) + if err != nil { + return nil, err + } + var cs string + if uri.Opaque != "" { // file:filename.db + cs = uri.Opaque + } else if uri.Path != "" { // file:///path/to/filename.db + cs = uri.Path + } else { + return nil, errors.New("no filename or path in connect string") + } + db, err := sqlutil.Open(internal.SQLiteDriverName(), cs, nil) if err != nil { return nil, err } diff --git a/internal/keydb/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go similarity index 100% rename from internal/keydb/sqlite3/server_key_table.go rename to serverkeyapi/storage/sqlite3/server_key_table.go diff --git a/syncapi/api/query.go b/syncapi/api/query.go index 6585bc093..1da05fcc8 100644 --- a/syncapi/api/query.go +++ b/syncapi/api/query.go @@ -24,10 +24,10 @@ import ( ) const ( - SyncAPIQuerySyncPath = "/api/syncapi/querySync" - SyncAPIQueryStatePath = "/api/syncapi/queryState" - SyncAPIQueryStateTypePath = "/api/syncapi/queryStateType" - SyncAPIQueryMessagesPath = "/api/syncapi/queryMessages" + SyncAPIQuerySyncPath = "/syncapi/querySync" + SyncAPIQueryStatePath = "/syncapi/queryState" + SyncAPIQueryStateTypePath = "/syncapi/queryStateType" + SyncAPIQueryMessagesPath = "/syncapi/queryMessages" ) func NewSyncQueryAPIHTTP(syncapiURL string, httpClient *http.Client) SyncQueryAPI { diff --git a/syncapi/consumers/eduserver_sendtodevice.go b/syncapi/consumers/eduserver_sendtodevice.go new file mode 100644 index 000000000..487018031 --- /dev/null +++ b/syncapi/consumers/eduserver_sendtodevice.go @@ -0,0 +1,113 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + "encoding/json" + + "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/sync" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + log "github.com/sirupsen/logrus" +) + +// OutputSendToDeviceEventConsumer consumes events that originated in the EDU server. +type OutputSendToDeviceEventConsumer struct { + sendToDeviceConsumer *internal.ContinualConsumer + db storage.Database + serverName gomatrixserverlib.ServerName // our server name + notifier *sync.Notifier +} + +// NewOutputSendToDeviceEventConsumer creates a new OutputSendToDeviceEventConsumer. +// Call Start() to begin consuming from the EDU server. +func NewOutputSendToDeviceEventConsumer( + cfg *config.Dendrite, + kafkaConsumer sarama.Consumer, + n *sync.Notifier, + store storage.Database, +) *OutputSendToDeviceEventConsumer { + + consumer := internal.ContinualConsumer{ + Topic: string(cfg.Kafka.Topics.OutputSendToDeviceEvent), + Consumer: kafkaConsumer, + PartitionStore: store, + } + + s := &OutputSendToDeviceEventConsumer{ + sendToDeviceConsumer: &consumer, + db: store, + serverName: cfg.Matrix.ServerName, + notifier: n, + } + + consumer.ProcessMessage = s.onMessage + + return s +} + +// Start consuming from EDU api +func (s *OutputSendToDeviceEventConsumer) Start() error { + return s.sendToDeviceConsumer.Start() +} + +func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { + var output api.OutputSendToDeviceEvent + if err := json.Unmarshal(msg.Value, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + return err + } + + _, domain, err := gomatrixserverlib.SplitID('@', output.UserID) + if err != nil { + return err + } + if domain != s.serverName { + return nil + } + + util.GetLogger(context.TODO()).WithFields(log.Fields{ + "sender": output.Sender, + "user_id": output.UserID, + "device_id": output.DeviceID, + "event_type": output.Type, + }).Info("sync API received send-to-device event from EDU server") + + streamPos := s.db.AddSendToDevice() + + _, err = s.db.StoreNewSendForDeviceMessage( + context.TODO(), streamPos, output.UserID, output.DeviceID, output.SendToDeviceEvent, + ) + if err != nil { + log.WithError(err).Errorf("failed to store send-to-device message") + return err + } + + s.notifier.OnNewSendToDevice( + output.UserID, + []string{output.DeviceID}, + types.NewStreamToken(0, streamPos), + ) + + return nil +} diff --git a/syncapi/consumers/eduserver.go b/syncapi/consumers/eduserver_typing.go similarity index 100% rename from syncapi/consumers/eduserver.go rename to syncapi/consumers/eduserver_typing.go diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 6300b17bb..c876164d7 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -30,7 +30,7 @@ import ( "github.com/matrix-org/util" ) -const pathPrefixR0 = "/_matrix/client/r0" +const pathPrefixR0 = "/client/r0" // Setup configures the given mux with sync-server listeners // @@ -38,12 +38,12 @@ const pathPrefixR0 = "/_matrix/client/r0" // applied: // nolint: gocyclo func Setup( - apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, + publicAPIMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, deviceDB devices.Database, federation *gomatrixserverlib.FederationClient, rsAPI api.RoomserverInternalAPI, cfg *config.Dendrite, ) { - r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() + r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter() authData := auth.Data{ AccountDB: nil, diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 3e9ff0533..566e5d589 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -55,10 +55,12 @@ type Database interface { // sync response for the given user. Events returned will include any client // transaction IDs associated with the given device. These transaction IDs come // from when the device sent the event via an API that included a transaction - // ID. - IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) - // CompleteSync returns a complete /sync API response for the given user. - CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error) + // ID. A response object must be provided for IncrementaSync to populate - it + // will not create one. + IncrementalSync(ctx context.Context, res *types.Response, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) + // CompleteSync returns a complete /sync API response for the given user. A response object + // must be provided for CompleteSync to populate - it will not create one. + CompleteSync(ctx context.Context, res *types.Response, device authtypes.Device, numRecentEventsPerRoom int) (*types.Response, error) // GetAccountDataInRange returns all account data for a given user inserted or // updated between two given positions // Returns a map following the format data[roomID] = []dataTypes @@ -104,4 +106,26 @@ type Database interface { StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) + // AddSendToDevice increases the EDU position in the cache and returns the stream position. + AddSendToDevice() types.StreamPosition + // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns three lists: + // - "events": a list of send-to-device events that should be included in the sync + // - "changes": a list of send-to-device events that should be updated in the database by + // CleanSendToDeviceUpdates + // - "deletions": a list of send-to-device events which have been confirmed as sent and + // can be deleted altogether by CleanSendToDeviceUpdates + // The token supplied should be the current requested sync token, e.g. from the "since" + // parameter. + SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error) + // StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device. + StoreNewSendForDeviceMessage(ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) + // CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the + // result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows + // SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after + // starting to wait for an incremental sync with timeout). + // The token supplied should be the current requested sync token, e.g. from the "since" + // parameter. + CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error) + // SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent. + SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error) } diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go new file mode 100644 index 000000000..335a05ef1 --- /dev/null +++ b/syncapi/storage/postgres/send_to_device_table.go @@ -0,0 +1,171 @@ +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +const sendToDeviceSchema = ` +CREATE SEQUENCE IF NOT EXISTS syncapi_send_to_device_id; + +-- Stores send-to-device messages. +CREATE TABLE IF NOT EXISTS syncapi_send_to_device ( + -- The ID that uniquely identifies this message. + id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_send_to_device_id'), + -- The user ID to send the message to. + user_id TEXT NOT NULL, + -- The device ID to send the message to. + device_id TEXT NOT NULL, + -- The event content JSON. + content TEXT NOT NULL, + -- The token that was supplied to the /sync at the time that this + -- message was included in a sync response, or NULL if we haven't + -- included it in a /sync response yet. + sent_by_token TEXT +); +` + +const insertSendToDeviceMessageSQL = ` + INSERT INTO syncapi_send_to_device (user_id, device_id, content) + VALUES ($1, $2, $3) +` + +const countSendToDeviceMessagesSQL = ` + SELECT COUNT(*) + FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 +` + +const selectSendToDeviceMessagesSQL = ` + SELECT id, user_id, device_id, content, sent_by_token + FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 + ORDER BY id DESC +` + +const updateSentSendToDeviceMessagesSQL = ` + UPDATE syncapi_send_to_device SET sent_by_token = $1 + WHERE id = ANY($2) +` + +const deleteSendToDeviceMessagesSQL = ` + DELETE FROM syncapi_send_to_device WHERE id = ANY($1) +` + +type sendToDeviceStatements struct { + insertSendToDeviceMessageStmt *sql.Stmt + countSendToDeviceMessagesStmt *sql.Stmt + selectSendToDeviceMessagesStmt *sql.Stmt + updateSentSendToDeviceMessagesStmt *sql.Stmt + deleteSendToDeviceMessagesStmt *sql.Stmt +} + +func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { + s := &sendToDeviceStatements{} + _, err := db.Exec(sendToDeviceSchema) + if err != nil { + return nil, err + } + if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { + return nil, err + } + if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil { + return nil, err + } + if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { + return nil, err + } + if s.updateSentSendToDeviceMessagesStmt, err = db.Prepare(updateSentSendToDeviceMessagesSQL); err != nil { + return nil, err + } + if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *sendToDeviceStatements) InsertSendToDeviceMessage( + ctx context.Context, txn *sql.Tx, userID, deviceID, content string, +) (err error) { + _, err = internal.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) + return +} + +func (s *sendToDeviceStatements) CountSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, +) (count int, err error) { + row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID) + if err = row.Scan(&count); err != nil { + return + } + return count, nil +} + +func (s *sendToDeviceStatements) SelectSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, +) (events []types.SendToDeviceEvent, err error) { + rows, err := internal.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed") + + for rows.Next() { + var id types.SendToDeviceNID + var userID, deviceID, content string + var sentByToken *string + if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil { + return + } + event := types.SendToDeviceEvent{ + ID: id, + UserID: userID, + DeviceID: deviceID, + } + if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { + return + } + if sentByToken != nil { + if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil { + event.SentByToken = &token + } + } + events = append(events, event) + } + + return events, rows.Err() +} + +func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, +) (err error) { + _, err = txn.Stmt(s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids)) + return +} + +func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, +) (err error) { + _, err = txn.Stmt(s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids)) + return +} diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index dc73350a2..8a8f964a5 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -69,6 +69,10 @@ func NewDatabase(dbDataSourceName string, dbProperties internal.DbProperties) (* if err != nil { return nil, err } + sendToDevice, err := NewPostgresSendToDeviceTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, Invites: invites, @@ -77,6 +81,8 @@ func NewDatabase(dbDataSourceName string, dbProperties internal.DbProperties) (* Topology: topology, CurrentRoomState: currState, BackwardExtremities: backwardExtremities, + SendToDevice: sendToDevice, + SendToDeviceWriter: internal.NewTransactionWriter(), EDUCache: cache.New(), } return &d, nil diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 145a3dbf3..497c043a6 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -1,3 +1,17 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package shared import ( @@ -27,6 +41,8 @@ type Database struct { Topology tables.Topology CurrentRoomState tables.CurrentRoomState BackwardExtremities tables.BackwardsExtremities + SendToDevice tables.SendToDevice + SendToDeviceWriter *internal.TransactionWriter EDUCache *cache.EDUCache } @@ -89,6 +105,10 @@ func (d *Database) RemoveTypingUser( return types.StreamPosition(d.EDUCache.RemoveUser(userID, roomID)) } +func (d *Database) AddSendToDevice() types.StreamPosition { + return types.StreamPosition(d.EDUCache.AddSendToDeviceMessage()) +} + func (d *Database) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) { d.EDUCache.SetTimeoutCallback(fn) } @@ -528,14 +548,14 @@ func (d *Database) addEDUDeltaToResponse( } func (d *Database) IncrementalSync( - ctx context.Context, + ctx context.Context, res *types.Response, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool, ) (*types.Response, error) { nextBatchPos := fromPos.WithUpdates(toPos) - res := types.NewResponse(nextBatchPos) + res.NextBatch = nextBatchPos.String() var joinedRoomIDs []string var err error @@ -568,12 +588,12 @@ func (d *Database) IncrementalSync( // getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed // to it. It returns toPos and joinedRoomIDs for use of adding EDUs. +// nolint:nakedret func (d *Database) getResponseWithPDUsForCompleteSync( - ctx context.Context, + ctx context.Context, res *types.Response, userID string, numRecentEventsPerRoom int, ) ( - res *types.Response, toPos types.StreamingToken, joinedRoomIDs []string, err error, @@ -604,7 +624,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync( To: toPos.PDUPosition(), } - res = types.NewResponse(toPos) + res.NextBatch = toPos.String() // Extract room state and recent events for all rooms the user is joined to. joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) @@ -662,14 +682,15 @@ func (d *Database) getResponseWithPDUsForCompleteSync( } succeeded = true - return res, toPos, joinedRoomIDs, err + return //res, toPos, joinedRoomIDs, err } func (d *Database) CompleteSync( - ctx context.Context, userID string, numRecentEventsPerRoom int, + ctx context.Context, res *types.Response, + device authtypes.Device, numRecentEventsPerRoom int, ) (*types.Response, error) { - res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync( - ctx, userID, numRecentEventsPerRoom, + toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync( + ctx, res, device.UserID, numRecentEventsPerRoom, ) if err != nil { return nil, err @@ -1028,6 +1049,115 @@ func (d *Database) currentStateStreamEventsForRoom( return s, nil } +func (d *Database) SendToDeviceUpdatesWaiting( + ctx context.Context, userID, deviceID string, +) (bool, error) { + count, err := d.SendToDevice.CountSendToDeviceMessages(ctx, nil, userID, deviceID) + if err != nil { + return false, err + } + return count > 0, nil +} + +func (d *Database) AddSendToDeviceEvent( + ctx context.Context, txn *sql.Tx, + userID, deviceID, content string, +) error { + return d.SendToDevice.InsertSendToDeviceMessage( + ctx, txn, userID, deviceID, content, + ) +} + +func (d *Database) StoreNewSendForDeviceMessage( + ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent, +) (types.StreamPosition, error) { + j, err := json.Marshal(event) + if err != nil { + return streamPos, err + } + // Delegate the database write task to the SendToDeviceWriter. It'll guarantee + // that we don't lock the table for writes in more than one place. + err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error { + return d.AddSendToDeviceEvent( + ctx, txn, userID, deviceID, string(j), + ) + }) + if err != nil { + return streamPos, err + } + return streamPos, nil +} + +func (d *Database) SendToDeviceUpdatesForSync( + ctx context.Context, + userID, deviceID string, + token types.StreamingToken, +) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) { + // First of all, get our send-to-device updates for this user. + events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID) + if err != nil { + return nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) + } + + // If there's nothing to do then stop here. + if len(events) == 0 { + return nil, nil, nil, nil + } + + // Work out whether we need to update any of the database entries. + toReturn := []types.SendToDeviceEvent{} + toUpdate := []types.SendToDeviceNID{} + toDelete := []types.SendToDeviceNID{} + for _, event := range events { + if event.SentByToken == nil { + // If the event has no sent-by token yet then we haven't attempted to send + // it. Record the current requested sync token in the database. + toUpdate = append(toUpdate, event.ID) + toReturn = append(toReturn, event) + event.SentByToken = &token + } else if token.IsAfter(*event.SentByToken) { + // The event had a sync token, therefore we've sent it before. The current + // sync token is now after the stored one so we can assume that the client + // successfully completed the previous sync (it would re-request it otherwise) + // so we can remove the entry from the database. + toDelete = append(toDelete, event.ID) + } else { + // It looks like the sync is being re-requested, maybe it timed out or + // failed. Re-send any that should have been acknowledged by now. + toReturn = append(toReturn, event) + } + } + + return toReturn, toUpdate, toDelete, nil +} + +func (d *Database) CleanSendToDeviceUpdates( + ctx context.Context, + toUpdate, toDelete []types.SendToDeviceNID, + token types.StreamingToken, +) (err error) { + if len(toUpdate) == 0 && len(toDelete) == 0 { + return nil + } + // If we need to write to the database then we'll ask the SendToDeviceWriter to + // do that for us. It'll guarantee that we don't lock the table for writes in + // more than one place. + err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error { + // Delete any send-to-device messages marked for deletion. + if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil { + return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e) + } + + // Now update any outstanding send-to-device messages with the new sync token. + if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil { + return fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err) + } + + return nil + }) + return +} + // There may be some overlap where events in stateEvents are already in recentEvents, so filter // them out so we don't include them twice in the /sync response. They should be in recentEvents // only, so clients get to the correct state once they have rolled forward. diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go new file mode 100644 index 000000000..0d03f23ef --- /dev/null +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -0,0 +1,172 @@ +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +const sendToDeviceSchema = ` +-- Stores send-to-device messages. +CREATE TABLE IF NOT EXISTS syncapi_send_to_device ( + -- The ID that uniquely identifies this message. + id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The user ID to send the message to. + user_id TEXT NOT NULL, + -- The device ID to send the message to. + device_id TEXT NOT NULL, + -- The event content JSON. + content TEXT NOT NULL, + -- The token that was supplied to the /sync at the time that this + -- message was included in a sync response, or NULL if we haven't + -- included it in a /sync response yet. + sent_by_token TEXT +); +` + +const insertSendToDeviceMessageSQL = ` + INSERT INTO syncapi_send_to_device (user_id, device_id, content) + VALUES ($1, $2, $3) +` + +const countSendToDeviceMessagesSQL = ` + SELECT COUNT(*) + FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 +` + +const selectSendToDeviceMessagesSQL = ` + SELECT id, user_id, device_id, content, sent_by_token + FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 + ORDER BY id DESC +` + +const updateSentSendToDeviceMessagesSQL = ` + UPDATE syncapi_send_to_device SET sent_by_token = $1 + WHERE id IN ($2) +` + +const deleteSendToDeviceMessagesSQL = ` + DELETE FROM syncapi_send_to_device WHERE id IN ($1) +` + +type sendToDeviceStatements struct { + insertSendToDeviceMessageStmt *sql.Stmt + selectSendToDeviceMessagesStmt *sql.Stmt + countSendToDeviceMessagesStmt *sql.Stmt +} + +func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { + s := &sendToDeviceStatements{} + _, err := db.Exec(sendToDeviceSchema) + if err != nil { + return nil, err + } + if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil { + return nil, err + } + if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { + return nil, err + } + if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *sendToDeviceStatements) InsertSendToDeviceMessage( + ctx context.Context, txn *sql.Tx, userID, deviceID, content string, +) (err error) { + _, err = internal.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) + return +} + +func (s *sendToDeviceStatements) CountSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, +) (count int, err error) { + row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID) + if err = row.Scan(&count); err != nil { + return + } + return count, nil +} + +func (s *sendToDeviceStatements) SelectSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, +) (events []types.SendToDeviceEvent, err error) { + rows, err := internal.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed") + + for rows.Next() { + var id types.SendToDeviceNID + var userID, deviceID, content string + var sentByToken *string + if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil { + return + } + event := types.SendToDeviceEvent{ + ID: id, + UserID: userID, + DeviceID: deviceID, + } + if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { + return + } + if sentByToken != nil { + if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil { + event.SentByToken = &token + } + } + events = append(events, event) + } + + return events, rows.Err() +} + +func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, +) (err error) { + query := strings.Replace(updateSentSendToDeviceMessagesSQL, "($2)", internal.QueryVariadic(1+len(nids)), 1) + params := make([]interface{}, 1+len(nids)) + params[0] = token + for k, v := range nids { + params[k+1] = v + } + _, err = txn.ExecContext(ctx, query, params...) + return +} + +func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, +) (err error) { + query := strings.Replace(deleteSendToDeviceMessagesSQL, "($1)", internal.QueryVariadic(len(nids)), 1) + params := make([]interface{}, 1+len(nids)) + for k, v := range nids { + params[k] = v + } + _, err = txn.ExecContext(ctx, query, params...) + return +} diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 8ab1d4040..5ba07617e 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -95,6 +95,10 @@ func (d *SyncServerDatasource) prepare() (err error) { if err != nil { return err } + sendToDevice, err := NewSqliteSendToDeviceTable(d.db) + if err != nil { + return err + } d.Database = shared.Database{ DB: d.db, Invites: invites, @@ -103,6 +107,8 @@ func (d *SyncServerDatasource) prepare() (err error) { BackwardExtremities: bwExtrem, CurrentRoomState: roomState, Topology: topology, + SendToDevice: sendToDevice, + SendToDeviceWriter: internal.NewTransactionWriter(), EDUCache: cache.New(), } return nil diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 6b0458527..4661ede4d 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -3,6 +3,7 @@ package storage_test import ( "context" "crypto/ed25519" + "encoding/json" "fmt" "testing" "time" @@ -157,7 +158,8 @@ func TestSyncResponse(t *testing.T) { from := types.NewStreamToken( // pretend we are at the penultimate event positions[len(positions)-2], types.StreamPosition(0), ) - return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false) + res := types.NewResponse() + return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) }, WantTimeline: events[len(events)-1:], }, @@ -169,8 +171,9 @@ func TestSyncResponse(t *testing.T) { from := types.NewStreamToken( // pretend we are 10 events behind positions[len(positions)-11], types.StreamPosition(0), ) + res := types.NewResponse() // limit is set to 5 - return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false) + return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) }, // want the last 5 events, NOT the last 10. WantTimeline: events[len(events)-5:], @@ -180,8 +183,9 @@ func TestSyncResponse(t *testing.T) { { Name: "CompleteSync limited", DoSync: func() (*types.Response, error) { + res := types.NewResponse() // limit set to 5 - return db.CompleteSync(ctx, testUserIDA, 5) + return db.CompleteSync(ctx, res, testUserDeviceA, 5) }, // want the last 5 events WantTimeline: events[len(events)-5:], @@ -193,7 +197,8 @@ func TestSyncResponse(t *testing.T) { { Name: "CompleteSync", DoSync: func() (*types.Response, error) { - return db.CompleteSync(ctx, testUserIDA, len(events)+1) + res := types.NewResponse() + return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1) }, WantTimeline: events, // We want no state at all as that field in /sync is the delta between the token (beginning of time) @@ -234,7 +239,8 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) { positions[len(positions)-2], types.StreamPosition(0), ) - res, err := db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false) + res := types.NewResponse() + res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) if err != nil { t.Fatalf("failed to IncrementalSync with latest token") } @@ -512,6 +518,89 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) { } } +func TestSendToDeviceBehaviour(t *testing.T) { + //t.Parallel() + db := MustCreateDatabase(t) + + // At this point there should be no messages. We haven't sent anything + // yet. + events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0)) + if err != nil { + t.Fatal(err) + } + if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 { + t.Fatal("first call should have no updates") + } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0)) + if err != nil { + return + } + + // Try sending a message. + streamPos, err := db.StoreNewSendForDeviceMessage(ctx, types.StreamPosition(0), "alice", "one", gomatrixserverlib.SendToDeviceEvent{ + Sender: "bob", + Type: "m.type", + Content: json.RawMessage("{}"), + }) + if err != nil { + t.Fatal(err) + } + + // At this point we should get exactly one message. We're sending the sync position + // that we were given from the update and the send-to-device update will be updated + // in the database to reflect that this was the sync position we sent the message at. + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) + if err != nil { + t.Fatal(err) + } + if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 { + t.Fatal("second call should have one update") + } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos)) + if err != nil { + return + } + + // At this point we should still have one message because we haven't progressed the + // sync position yet. This is equivalent to the client failing to /sync and retrying + // with the same position. + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) + if err != nil { + t.Fatal(err) + } + if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 { + t.Fatal("third call should have one update still") + } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos)) + if err != nil { + return + } + + // At this point we should now have no updates, because we've progressed the sync + // position. Therefore the update from before will not be sent again. + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1)) + if err != nil { + t.Fatal(err) + } + if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 { + t.Fatal("fourth call should have no updates") + } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1)) + if err != nil { + return + } + + // At this point we should still have no updates, because no new updates have been + // sent. + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2)) + if err != nil { + t.Fatal(err) + } + if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 { + t.Fatal("fifth call should have no updates") + } +} + func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) { if len(gots) != len(wants) { t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants)) diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index bc3b6941a..0b7d15951 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -1,3 +1,17 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package tables import ( @@ -94,3 +108,28 @@ type BackwardsExtremities interface { // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) } + +// SendToDevice tracks send-to-device messages which are sent to individual +// clients. Each message gets inserted into this table at the point that we +// receive it from the EDU server. +// +// We're supposed to try and do our best to deliver send-to-device messages +// once, but the only way that we can really guarantee that they have been +// delivered is if the client successfully requests the next sync as given +// in the next_batch. Each time the device syncs, we will request all of the +// updates that either haven't been sent yet, along with all updates that we +// *have* sent but we haven't confirmed to have been received yet. If it's the +// first time we're sending a given update then we update the table to say +// what the "since" parameter was when we tried to send it. +// +// When the client syncs again, if their "since" parameter is *later* than +// the recorded one, we drop the entry from the DB as it's "sent". If the +// sync parameter isn't later then we will keep including the updates in the +// sync response, as the client is seemingly trying to repeat the same /sync. +type SendToDevice interface { + InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (err error) + SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (events []types.SendToDeviceEvent, err error) + UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error) + DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error) + CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error) +} diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index b3ed5cd03..325e75351 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -37,8 +37,8 @@ type Notifier struct { streamLock *sync.Mutex // The latest sync position currPos types.StreamingToken - // A map of user_id => UserStream which can be used to wake a given user's /sync request. - userStreams map[string]*UserStream + // A map of user_id => device_id => UserStream which can be used to wake a given user's /sync request. + userDeviceStreams map[string]map[string]*UserDeviceStream // The last time we cleaned out stale entries from the userStreams map lastCleanUpTime time.Time } @@ -50,7 +50,7 @@ func NewNotifier(pos types.StreamingToken) *Notifier { return &Notifier{ currPos: pos, roomIDToJoinedUsers: make(map[string]userIDSet), - userStreams: make(map[string]*UserStream), + userDeviceStreams: make(map[string]map[string]*UserDeviceStream), streamLock: &sync.Mutex{}, lastCleanUpTime: time.Now(), } @@ -120,10 +120,22 @@ func (n *Notifier) OnNewEvent( } } +func (n *Notifier) OnNewSendToDevice( + userID string, deviceIDs []string, + posUpdate types.StreamingToken, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + latestPos := n.currPos.WithUpdates(posUpdate) + n.currPos = latestPos + + n.wakeupUserDevice(userID, deviceIDs, latestPos) +} + // GetListener returns a UserStreamListener that can be used to wait for // updates for a user. Must be closed. // notify for anything before sincePos -func (n *Notifier) GetListener(req syncRequest) UserStreamListener { +func (n *Notifier) GetListener(req syncRequest) UserDeviceStreamListener { // Do what synapse does: https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/notifier.py#L298 // - Bucket request into a lookup map keyed off a list of joined room IDs and separately a user ID // - Incoming events wake requests for a matching room ID @@ -137,7 +149,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener { n.removeEmptyUserStreams() - return n.fetchUserStream(req.device.UserID, true).GetListener(req.ctx) + return n.fetchUserDeviceStream(req.device.UserID, req.device.ID, true).GetListener(req.ctx) } // Load the membership states required to notify users correctly. @@ -173,27 +185,69 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) { } } +// wakeupUsers will wake up the sync strems for all of the devices for all of the +// specified user IDs. func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) { for _, userID := range userIDs { - stream := n.fetchUserStream(userID, false) - if stream != nil { + for _, stream := range n.fetchUserStreams(userID) { + if stream == nil { + continue + } stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream } } } -// fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true, +// wakeupUserDevice will wake up the sync stream for a specific user device. Other +// device streams will be left alone. +// nolint:unused +func (n *Notifier) wakeupUserDevice(userID string, deviceIDs []string, newPos types.StreamingToken) { + for _, deviceID := range deviceIDs { + if stream := n.fetchUserDeviceStream(userID, deviceID, false); stream != nil { + stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream + } + } +} + +// fetchUserDeviceStream retrieves a stream unique to the given device. If makeIfNotExists is true, +// a stream will be made for this device if one doesn't exist and it will be returned. This +// function does not wait for data to be available on the stream. +// NB: Callers should have locked the mutex before calling this function. +func (n *Notifier) fetchUserDeviceStream(userID, deviceID string, makeIfNotExists bool) *UserDeviceStream { + _, ok := n.userDeviceStreams[userID] + if !ok { + if !makeIfNotExists { + return nil + } + n.userDeviceStreams[userID] = map[string]*UserDeviceStream{} + } + stream, ok := n.userDeviceStreams[userID][deviceID] + if !ok { + if !makeIfNotExists { + return nil + } + // TODO: Unbounded growth of streams (1 per user) + if stream = NewUserDeviceStream(userID, deviceID, n.currPos); stream != nil { + n.userDeviceStreams[userID][deviceID] = stream + } + } + return stream +} + +// fetchUserStreams retrieves all streams for the given user. If makeIfNotExists is true, // a stream will be made for this user if one doesn't exist and it will be returned. This // function does not wait for data to be available on the stream. // NB: Callers should have locked the mutex before calling this function. -func (n *Notifier) fetchUserStream(userID string, makeIfNotExists bool) *UserStream { - stream, ok := n.userStreams[userID] - if !ok && makeIfNotExists { - // TODO: Unbounded growth of streams (1 per user) - stream = NewUserStream(userID, n.currPos) - n.userStreams[userID] = stream +func (n *Notifier) fetchUserStreams(userID string) []*UserDeviceStream { + user, ok := n.userDeviceStreams[userID] + if !ok { + return []*UserDeviceStream{} } - return stream + streams := []*UserDeviceStream{} + for _, stream := range user { + streams = append(streams, stream) + } + return streams } // Not thread-safe: must be called on the OnNewEvent goroutine only @@ -236,9 +290,14 @@ func (n *Notifier) removeEmptyUserStreams() { n.lastCleanUpTime = now deleteBefore := now.Add(-5 * time.Minute) - for key, value := range n.userStreams { - if value.TimeOfLastNonEmpty().Before(deleteBefore) { - delete(n.userStreams, key) + for user, byUser := range n.userDeviceStreams { + for device, stream := range byUser { + if stream.TimeOfLastNonEmpty().Before(deleteBefore) { + delete(n.userDeviceStreams[user], device) + } + if len(n.userDeviceStreams[user]) == 0 { + delete(n.userDeviceStreams, user) + } } } } diff --git a/syncapi/sync/notifier_test.go b/syncapi/sync/notifier_test.go index 7d979fcc9..132315573 100644 --- a/syncapi/sync/notifier_test.go +++ b/syncapi/sync/notifier_test.go @@ -41,9 +41,11 @@ var ( ) var ( - roomID = "!test:localhost" - alice = "@alice:localhost" - bob = "@bob:localhost" + roomID = "!test:localhost" + alice = "@alice:localhost" + aliceDev = "alicedevice" + bob = "@bob:localhost" + bobDev = "bobdev" ) func init() { @@ -107,7 +109,7 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) { // Test that the current position is returned if a request is already behind. func TestImmediateNotification(t *testing.T) { n := NewNotifier(syncPositionBefore) - pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionVeryOld)) + pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld)) if err != nil { t.Fatalf("TestImmediateNotification error: %s", err) } @@ -124,7 +126,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) if err != nil { t.Errorf("TestNewEventAndJoinedToRoom error: %w", err) } @@ -132,7 +134,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { wg.Done() }() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, bob, bobDev) waitForBlocking(stream, 1) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) @@ -140,6 +142,43 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { wg.Wait() } +func TestCorrectStream(t *testing.T) { + n := NewNotifier(syncPositionBefore) + stream := lockedFetchUserStream(n, bob, bobDev) + if stream.UserID != bob { + t.Fatalf("expected user %q, got %q", bob, stream.UserID) + } + if stream.DeviceID != bobDev { + t.Fatalf("expected device %q, got %q", bobDev, stream.DeviceID) + } +} + +func TestCorrectStreamWakeup(t *testing.T) { + n := NewNotifier(syncPositionBefore) + awoken := make(chan string) + + streamone := lockedFetchUserStream(n, alice, "one") + streamtwo := lockedFetchUserStream(n, alice, "two") + + go func() { + select { + case <-streamone.signalChannel: + awoken <- "one" + case <-streamtwo.signalChannel: + awoken <- "two" + } + }() + + time.Sleep(1 * time.Second) + + wake := "two" + n.wakeupUserDevice(alice, []string{wake}, syncPositionAfter) + + if result := <-awoken; result != wake { + t.Fatalf("expected to wake %q, got %q", wake, result) + } +} + // Test that an invite unblocks the request func TestNewInviteEventForUser(t *testing.T) { n := NewNotifier(syncPositionBefore) @@ -150,7 +189,7 @@ func TestNewInviteEventForUser(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) if err != nil { t.Errorf("TestNewInviteEventForUser error: %w", err) } @@ -158,7 +197,7 @@ func TestNewInviteEventForUser(t *testing.T) { wg.Done() }() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, bob, bobDev) waitForBlocking(stream, 1) n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionAfter) @@ -176,7 +215,7 @@ func TestEDUWakeup(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter)) if err != nil { t.Errorf("TestNewInviteEventForUser error: %w", err) } @@ -184,7 +223,7 @@ func TestEDUWakeup(t *testing.T) { wg.Done() }() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, bob, bobDev) waitForBlocking(stream, 1) n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionNewEDU) @@ -202,7 +241,7 @@ func TestMultipleRequestWakeup(t *testing.T) { var wg sync.WaitGroup wg.Add(3) poll := func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) if err != nil { t.Errorf("TestMultipleRequestWakeup error: %w", err) } @@ -213,7 +252,7 @@ func TestMultipleRequestWakeup(t *testing.T) { go poll() go poll() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, bob, bobDev) waitForBlocking(stream, 3) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) @@ -240,24 +279,24 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { // Make bob leave the room leaveWG.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) if err != nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) } mustEqualPositions(t, pos, syncPositionAfter) leaveWG.Done() }() - bobStream := lockedFetchUserStream(n, bob) + bobStream := lockedFetchUserStream(n, bob, bobDev) waitForBlocking(bobStream, 1) n.OnNewEvent(&bobLeaveEvent, "", nil, syncPositionAfter) leaveWG.Wait() // send an event into the room. Make sure alice gets it. Bob should not. var aliceWG sync.WaitGroup - aliceStream := lockedFetchUserStream(n, alice) + aliceStream := lockedFetchUserStream(n, alice, aliceDev) aliceWG.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionAfter)) + pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionAfter)) if err != nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) } @@ -267,7 +306,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { go func() { // this should timeout with an error (but the main goroutine won't wait for the timeout explicitly) - _, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter)) + _, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter)) if err == nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil") } @@ -300,7 +339,7 @@ func waitForEvents(n *Notifier, req syncRequest) (types.StreamingToken, error) { } // Wait until something is Wait()ing on the user stream. -func waitForBlocking(s *UserStream, numBlocking uint) { +func waitForBlocking(s *UserDeviceStream, numBlocking uint) { for numBlocking != s.NumWaiting() { // This is horrible but I don't want to add a signalling mechanism JUST for testing. time.Sleep(1 * time.Microsecond) @@ -309,16 +348,19 @@ func waitForBlocking(s *UserStream, numBlocking uint) { // lockedFetchUserStream invokes Notifier.fetchUserStream, respecting Notifier.streamLock. // A new stream is made if it doesn't exist already. -func lockedFetchUserStream(n *Notifier, userID string) *UserStream { +func lockedFetchUserStream(n *Notifier, userID, deviceID string) *UserDeviceStream { n.streamLock.Lock() defer n.streamLock.Unlock() - return n.fetchUserStream(userID, true) + return n.fetchUserDeviceStream(userID, deviceID, true) } -func newTestSyncRequest(userID string, since types.StreamingToken) syncRequest { +func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) syncRequest { return syncRequest{ - device: authtypes.Device{UserID: userID}, + device: authtypes.Device{ + UserID: userID, + ID: deviceID, + }, timeout: 1 * time.Minute, since: &since, wantFullState: false, diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 82ce283b6..8b93cad45 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -1,4 +1,6 @@ // Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,6 +17,7 @@ package sync import ( + "context" "net/http" "time" @@ -47,7 +50,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype var syncData *types.Response // Extract values from request - userID := device.UserID syncReq, err := newSyncRequest(req, *device) if err != nil { return util.JSONResponse{ @@ -55,16 +57,18 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype JSON: jsonerror.Unknown(err.Error()), } } + logger := util.GetLogger(req.Context()).WithFields(log.Fields{ - "userID": userID, - "since": syncReq.since, - "timeout": syncReq.timeout, - "limit": syncReq.limit, + "user_id": device.UserID, + "device_id": device.ID, + "since": syncReq.since, + "timeout": syncReq.timeout, + "limit": syncReq.limit, }) currPos := rp.notifier.CurrentPosition() - if shouldReturnImmediately(syncReq) { + if rp.shouldReturnImmediately(syncReq) { syncData, err = rp.currentSyncForUser(*syncReq, currPos) if err != nil { logger.WithError(err).Error("rp.currentSyncForUser failed") @@ -116,7 +120,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype // response. This ensures that we don't waste the hard work // of calculating the sync only to get timed out before we // can respond - syncData, err = rp.currentSyncForUser(*syncReq, currPos) if err != nil { logger.WithError(err).Error("rp.currentSyncForUser failed") @@ -134,19 +137,59 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype } func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) { - // TODO: handle ignored users - if req.since == nil { - res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) - } else { - res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit, req.wantFullState) + res = types.NewResponse() + + since := types.NewStreamToken(0, 0) + if req.since != nil { + since = *req.since } + // See if we have any new tasks to do for the send-to-device messaging. + events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, since) + if err != nil { + return nil, err + } + + // TODO: handle ignored users + if req.since == nil { + res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit) + } else { + res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState) + } if err != nil { return } accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter) + if err != nil { + return + } + + // Before we return the sync response, make sure that we take action on + // any send-to-device database updates or deletions that we need to do. + // Then add the updates into the sync response. + if len(updates) > 0 || len(deletions) > 0 { + // Handle the updates and deletions in the database. + err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, since) + if err != nil { + return + } + } + if len(events) > 0 { + // Add the updates into the sync response. + for _, event := range events { + res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent) + } + + // Get the next_batch from the sync response and increase the + // EDU counter. + if pos, perr := types.NewStreamTokenFromString(res.NextBatch); perr == nil { + pos.Positions[1]++ + res.NextBatch = pos.String() + } + } + return } @@ -238,6 +281,10 @@ func (rp *RequestPool) appendAccountData( // shouldReturnImmediately returns whether the /sync request is an initial sync, // or timeout=0, or full_state=true, in any of the cases the request should // return immediately. -func shouldReturnImmediately(syncReq *syncRequest) bool { - return syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState +func (rp *RequestPool) shouldReturnImmediately(syncReq *syncRequest) bool { + if syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState { + return true + } + waiting, werr := rp.db.SendToDeviceUpdatesWaiting(context.TODO(), syncReq.device.UserID, syncReq.device.ID) + return werr == nil && waiting } diff --git a/syncapi/sync/userstream.go b/syncapi/sync/userstream.go index b2eafa3dc..ff9a4d003 100644 --- a/syncapi/sync/userstream.go +++ b/syncapi/sync/userstream.go @@ -23,12 +23,13 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" ) -// UserStream represents a communication mechanism between the /sync request goroutine +// UserDeviceStream represents a communication mechanism between the /sync request goroutine // and the underlying sync server goroutines. // Goroutines can get a UserStreamListener to wait for updates, and can Broadcast() // updates. -type UserStream struct { - UserID string +type UserDeviceStream struct { + UserID string + DeviceID string // The lock that protects changes to this struct lock sync.Mutex // Closed when there is an update. @@ -41,18 +42,19 @@ type UserStream struct { numWaiting uint } -// UserStreamListener allows a sync request to wait for updates for a user. -type UserStreamListener struct { - userStream *UserStream +// UserDeviceStreamListener allows a sync request to wait for updates for a user. +type UserDeviceStreamListener struct { + userStream *UserDeviceStream // Whether the stream has been closed hasClosed bool } -// NewUserStream creates a new user stream -func NewUserStream(userID string, currPos types.StreamingToken) *UserStream { - return &UserStream{ +// NewUserDeviceStream creates a new user stream +func NewUserDeviceStream(userID, deviceID string, currPos types.StreamingToken) *UserDeviceStream { + return &UserDeviceStream{ UserID: userID, + DeviceID: deviceID, timeOfLastChannel: time.Now(), pos: currPos, signalChannel: make(chan struct{}), @@ -62,18 +64,18 @@ func NewUserStream(userID string, currPos types.StreamingToken) *UserStream { // GetListener returns UserStreamListener that a sync request can use to wait // for new updates with. // UserStreamListener must be closed -func (s *UserStream) GetListener(ctx context.Context) UserStreamListener { +func (s *UserDeviceStream) GetListener(ctx context.Context) UserDeviceStreamListener { s.lock.Lock() defer s.lock.Unlock() s.numWaiting++ // We decrement when UserStreamListener is closed - listener := UserStreamListener{ + listener := UserDeviceStreamListener{ userStream: s, } // Lets be a bit paranoid here and check that Close() is being called - runtime.SetFinalizer(&listener, func(l *UserStreamListener) { + runtime.SetFinalizer(&listener, func(l *UserDeviceStreamListener) { if !l.hasClosed { l.Close() } @@ -83,7 +85,7 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener { } // Broadcast a new sync position for this user. -func (s *UserStream) Broadcast(pos types.StreamingToken) { +func (s *UserDeviceStream) Broadcast(pos types.StreamingToken) { s.lock.Lock() defer s.lock.Unlock() @@ -96,7 +98,7 @@ func (s *UserStream) Broadcast(pos types.StreamingToken) { // NumWaiting returns the number of goroutines waiting for waiting for updates. // Used for metrics and testing. -func (s *UserStream) NumWaiting() uint { +func (s *UserDeviceStream) NumWaiting() uint { s.lock.Lock() defer s.lock.Unlock() return s.numWaiting @@ -105,7 +107,7 @@ func (s *UserStream) NumWaiting() uint { // TimeOfLastNonEmpty returns the last time that the number of waiting listeners // was non-empty, may be time.Now() if number of waiting listeners is currently // non-empty. -func (s *UserStream) TimeOfLastNonEmpty() time.Time { +func (s *UserDeviceStream) TimeOfLastNonEmpty() time.Time { s.lock.Lock() defer s.lock.Unlock() @@ -118,7 +120,7 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time { // GetSyncPosition returns last sync position which the UserStream was // notified about -func (s *UserStreamListener) GetSyncPosition() types.StreamingToken { +func (s *UserDeviceStreamListener) GetSyncPosition() types.StreamingToken { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() @@ -130,7 +132,7 @@ func (s *UserStreamListener) GetSyncPosition() types.StreamingToken { // sincePos specifies from which point we want to be notified about. If there // has already been an update after sincePos we'll return a closed channel // immediately. -func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-chan struct{} { +func (s *UserDeviceStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-chan struct{} { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() @@ -147,7 +149,7 @@ func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-c } // Close cleans up resources used -func (s *UserStreamListener) Close() { +func (s *UserDeviceStreamListener) Close() { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 022ed7bab..762f4e9d3 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -78,8 +78,15 @@ func SetupSyncAPIComponent( base.Cfg, base.KafkaConsumer, notifier, syncDB, ) if err = typingConsumer.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start typing server consumer") + logrus.WithError(err).Panicf("failed to start typing consumer") } - routing.Setup(base.APIMux, requestPool, syncDB, deviceDB, federation, rsAPI, cfg) + sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer( + base.Cfg, base.KafkaConsumer, notifier, syncDB, + ) + if err = sendToDeviceConsumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start send-to-device consumer") + } + + routing.Setup(base.PublicAPIMux, requestPool, syncDB, deviceDB, federation, rsAPI, cfg) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index caa1b3ade..c1f09fba5 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -296,13 +296,14 @@ type Response struct { Invite map[string]InviteResponse `json:"invite"` Leave map[string]LeaveResponse `json:"leave"` } `json:"rooms"` + ToDevice struct { + Events []gomatrixserverlib.SendToDeviceEvent `json:"events"` + } `json:"to_device"` } // NewResponse creates an empty response with initialised maps. -func NewResponse(token StreamingToken) *Response { - res := Response{ - NextBatch: token.String(), - } +func NewResponse() *Response { + res := Response{} // Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section, // so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors. res.Rooms.Join = make(map[string]JoinResponse) @@ -315,6 +316,7 @@ func NewResponse(token StreamingToken) *Response { // This also applies to NewJoinResponse, NewInviteResponse and NewLeaveResponse. res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0) + res.ToDevice.Events = make([]gomatrixserverlib.SendToDeviceEvent, 0) return &res } @@ -326,7 +328,8 @@ func (r *Response) IsEmpty() bool { len(r.Rooms.Invite) == 0 && len(r.Rooms.Leave) == 0 && len(r.AccountData.Events) == 0 && - len(r.Presence.Events) == 0 + len(r.Presence.Events) == 0 && + len(r.ToDevice.Events) == 0 } // JoinResponse represents a /sync response for a room which is under the 'join' key. @@ -393,3 +396,13 @@ func NewLeaveResponse() *LeaveResponse { res.Timeline.Events = make([]gomatrixserverlib.ClientEvent, 0) return &res } + +type SendToDeviceNID int + +type SendToDeviceEvent struct { + gomatrixserverlib.SendToDeviceEvent + ID SendToDeviceNID + UserID string + DeviceID string + SentByToken *StreamingToken +} diff --git a/sytest-blacklist b/sytest-blacklist index caad25455..1efc207f7 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -39,3 +39,8 @@ Ignore invite in incremental sync # Blacklisted because this test calls /r0/events which we don't implement New room members see their own join event Existing members see new members' join events + +# Blacklisted because the federation work for these hasn't been finished yet. +Can recv device messages over federation +Device messages over federation wake up /sync +Wildcard device messages over federation wake up /sync diff --git a/sytest-whitelist b/sytest-whitelist index 1b12aba1b..6236b28ed 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -288,3 +288,16 @@ New room members see their own join event Existing members see new members' join events Inbound federation can receive events Inbound federation can receive redacted events +Can logout current device +Can send a message directly to a device using PUT /sendToDevice +Can recv a device message using /sync +Can recv device messages until they are acknowledged +Device messages with the same txn_id are deduplicated +Device messages wake up /sync +# TODO: separate PR for: Can recv device messages over federation +# TODO: separate PR for: Device messages over federation wake up /sync +Can send messages with a wildcard device id +Can send messages with a wildcard device id to two devices +Wildcard device messages wake up /sync +# TODO: separate PR for: Wildcard device messages over federation wake up /sync +Can send a to-device message to two users which both receive it using /sync